Renamed internal functions to make them protected

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 16:38:58 +01:00
parent 483cb4eb3e
commit 028513b404

View File

@ -88,10 +88,10 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
# get dataset # get dataset
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y)) train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size,
drop_remainder=True).map(normalize) drop_remainder=True).map(_normalize)
valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y)) valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y))
valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size, valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size,
drop_remainder=True).map(normalize) drop_remainder=True).map(_normalize)
return train_dataset, valid_dataset return train_dataset, valid_dataset
@ -120,7 +120,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
y_fake = k.zeros(batch_size) y_fake = k.zeros(batch_size)
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
# z generator function # z generator function
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
# non-preserved python variables # non-preserved python variables
encoder_lowest_loss = math.inf encoder_lowest_loss = math.inf
@ -272,7 +272,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
batch_iteration = k.variable(0, dtype=tf.int64) batch_iteration = k.variable(0, dtype=tf.int64)
for x, _ in dataset: for x, _ in dataset:
# x discriminator # x discriminator
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator, _xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
decoder=decoder, decoder=decoder,
optimizer=x_discriminator_optimizer, optimizer=x_discriminator_optimizer,
inputs=x, inputs=x,
@ -284,7 +284,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
# -------- # --------
# decoder # decoder
_decoder_train_loss = train_decoder_step(decoder=decoder, _decoder_train_loss = _train_decoder_step(decoder=decoder,
x_discriminator=x_discriminator, x_discriminator=x_discriminator,
optimizer=decoder_optimizer, optimizer=decoder_optimizer,
targets=targets_real, targets=targets_real,
@ -294,7 +294,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
# --------- # ---------
# z discriminator # z discriminator
_zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator, _zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator,
encoder=encoder, encoder=encoder,
optimizer=z_discriminator_optimizer, optimizer=z_discriminator_optimizer,
inputs=x, inputs=x,
@ -306,7 +306,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
# ----------- # -----------
# encoder + decoder # encoder + decoder
encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder, encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder,
decoder=decoder, decoder=decoder,
z_discriminator=z_discriminator, z_discriminator=z_discriminator,
optimizer=enc_dec_optimizer, optimizer=enc_dec_optimizer,
@ -349,7 +349,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
return outputs return outputs
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor, inputs: tf.Tensor, targets_real: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable, targets_fake: tf.Tensor, global_step: tf.Variable,
@ -385,7 +385,7 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
return _xd_train_loss return _xd_train_loss
def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator, def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
targets: tf.Tensor, global_step: tf.Variable, targets: tf.Tensor, global_step: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor: z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
@ -414,7 +414,7 @@ def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
return _decoder_train_loss return _decoder_train_loss
def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder, def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor, inputs: tf.Tensor, targets_real: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable, targets_fake: tf.Tensor, global_step: tf.Variable,
@ -451,7 +451,7 @@ def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
return _zd_train_loss return _zd_train_loss
def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator, def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
optimizer: tf.train.Optimizer, inputs: tf.Tensor, optimizer: tf.train.Optimizer, inputs: tf.Tensor,
targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
""" """
@ -484,7 +484,7 @@ def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDis
return encoder_loss, reconstruction_loss, x_decoded return encoder_loss, reconstruction_loss, x_decoded
def get_z_variable(batch_size: int, zsize: int) -> tf.Variable: def _get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
""" """
Creates and returns a z variable taken from a normal distribution. Creates and returns a z variable taken from a normal distribution.
@ -496,7 +496,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
return k.variable(z) return k.variable(z)
def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
""" """
Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension. Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.