Renamed internal functions to make them protected
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -88,10 +88,10 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
||||
# get dataset
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size,
|
||||
drop_remainder=True).map(normalize)
|
||||
drop_remainder=True).map(_normalize)
|
||||
valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y))
|
||||
valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size,
|
||||
drop_remainder=True).map(normalize)
|
||||
drop_remainder=True).map(_normalize)
|
||||
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
@ -120,7 +120,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
y_fake = k.zeros(batch_size)
|
||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
||||
# z generator function
|
||||
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||
z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||
|
||||
# non-preserved python variables
|
||||
encoder_lowest_loss = math.inf
|
||||
@ -272,7 +272,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
||||
for x, _ in dataset:
|
||||
# x discriminator
|
||||
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||
decoder=decoder,
|
||||
optimizer=x_discriminator_optimizer,
|
||||
inputs=x,
|
||||
@ -284,7 +284,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
|
||||
# --------
|
||||
# decoder
|
||||
_decoder_train_loss = train_decoder_step(decoder=decoder,
|
||||
_decoder_train_loss = _train_decoder_step(decoder=decoder,
|
||||
x_discriminator=x_discriminator,
|
||||
optimizer=decoder_optimizer,
|
||||
targets=targets_real,
|
||||
@ -294,7 +294,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
|
||||
# ---------
|
||||
# z discriminator
|
||||
_zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
|
||||
_zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator,
|
||||
encoder=encoder,
|
||||
optimizer=z_discriminator_optimizer,
|
||||
inputs=x,
|
||||
@ -306,7 +306,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
|
||||
# -----------
|
||||
# encoder + decoder
|
||||
encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
|
||||
encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder,
|
||||
decoder=decoder,
|
||||
z_discriminator=z_discriminator,
|
||||
optimizer=enc_dec_optimizer,
|
||||
@ -349,7 +349,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
return outputs
|
||||
|
||||
|
||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
optimizer: tf.train.Optimizer,
|
||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||
@ -385,7 +385,7 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
return _xd_train_loss
|
||||
|
||||
|
||||
def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
||||
def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
||||
optimizer: tf.train.Optimizer,
|
||||
targets: tf.Tensor, global_step: tf.Variable,
|
||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||
@ -414,7 +414,7 @@ def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
||||
return _decoder_train_loss
|
||||
|
||||
|
||||
def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
|
||||
def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
|
||||
optimizer: tf.train.Optimizer,
|
||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||
@ -451,7 +451,7 @@ def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
|
||||
return _zd_train_loss
|
||||
|
||||
|
||||
def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
||||
def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
||||
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
||||
targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
@ -484,7 +484,7 @@ def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDis
|
||||
return encoder_loss, reconstruction_loss, x_decoded
|
||||
|
||||
|
||||
def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
def _get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
"""
|
||||
Creates and returns a z variable taken from a normal distribution.
|
||||
|
||||
@ -496,7 +496,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
return k.variable(z)
|
||||
|
||||
|
||||
def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user