Renamed internal functions to make them protected
Signed-off-by: Jim Martens <github@2martens.de>
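Note on terminology: Python does not enforce a protected visibility level; a single leading underscore is the PEP 8 convention for marking a name as internal to its module. A minimal sketch of what the rename buys, using hypothetical names rather than functions from this module:

    # A leading underscore marks _helper as internal by convention and
    # keeps it out of wildcard imports.
    def _helper() -> int:
        return 42

    def public_api() -> int:
        return _helper()

    # 'from mymodule import *' binds public_api but not _helper,
    # unless _helper is listed explicitly in __all__.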
@@ -88,10 +88,10 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
     # get dataset
     train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
     train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size,
-                                                                        drop_remainder=True).map(normalize)
+                                                                        drop_remainder=True).map(_normalize)
     valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y))
     valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size,
-                                                                        drop_remainder=True).map(normalize)
+                                                                        drop_remainder=True).map(_normalize)
 
     return train_dataset, valid_dataset
 
@@ -120,7 +120,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
     y_fake = k.zeros(batch_size)
     sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
     # z generator function
-    z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
+    z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
 
     # non-preserved python variables
     encoder_lowest_loss = math.inf
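The functools.partial call in this hunk is what lets the later training steps accept a zero-argument z_generator: Callable[[], tf.Variable]. A self-contained sketch of the mechanics, with an illustrative NumPy body standing in for the real _get_z_variable (renamed in the hunk at line 484 below):

    import functools

    import numpy as np

    # Stand-in body; only the call pattern matters for this sketch.
    def _get_z_variable(batch_size: int, zsize: int) -> np.ndarray:
        return np.random.normal(size=(batch_size, zsize))

    # partial freezes the keyword arguments, so the training loop can draw
    # a fresh latent sample with a plain no-argument call.
    z_generator = functools.partial(_get_z_variable, batch_size=64, zsize=32)
    z = z_generator()  # same as _get_z_variable(batch_size=64, zsize=32)
    assert z.shape == (64, 32)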
@@ -272,7 +272,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
     batch_iteration = k.variable(0, dtype=tf.int64)
     for x, _ in dataset:
         # x discriminator
-        _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
+        _xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
                                                    decoder=decoder,
                                                    optimizer=x_discriminator_optimizer,
                                                    inputs=x,
@@ -284,7 +284,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
 
         # --------
         # decoder
-        _decoder_train_loss = train_decoder_step(decoder=decoder,
+        _decoder_train_loss = _train_decoder_step(decoder=decoder,
                                                  x_discriminator=x_discriminator,
                                                  optimizer=decoder_optimizer,
                                                  targets=targets_real,
@@ -294,7 +294,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
 
         # ---------
         # z discriminator
-        _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
+        _zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator,
                                                    encoder=encoder,
                                                    optimizer=z_discriminator_optimizer,
                                                    inputs=x,
@@ -306,7 +306,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
 
         # -----------
         # encoder + decoder
-        encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
+        encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder,
                                                                           decoder=decoder,
                                                                           z_discriminator=z_discriminator,
                                                                           optimizer=enc_dec_optimizer,
@@ -349,7 +349,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
     return outputs
 
 
-def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
+def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
                               optimizer: tf.train.Optimizer,
                               inputs: tf.Tensor, targets_real: tf.Tensor,
                               targets_fake: tf.Tensor, global_step: tf.Variable,
@@ -385,7 +385,7 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
     return _xd_train_loss
 
 
-def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
+def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
                        optimizer: tf.train.Optimizer,
                        targets: tf.Tensor, global_step: tf.Variable,
                        z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
@@ -414,7 +414,7 @@ def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
     return _decoder_train_loss
 
 
-def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
+def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
                               optimizer: tf.train.Optimizer,
                               inputs: tf.Tensor, targets_real: tf.Tensor,
                               targets_fake: tf.Tensor, global_step: tf.Variable,
@@ -451,7 +451,7 @@ def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
     return _zd_train_loss
 
 
-def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
+def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
                        optimizer: tf.train.Optimizer, inputs: tf.Tensor,
                        targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
     """
@@ -484,7 +484,7 @@ def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDis
     return encoder_loss, reconstruction_loss, x_decoded
 
 
-def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
+def _get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
     """
     Creates and returns a z variable taken from a normal distribution.
 
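From the docstring and the return k.variable(z) visible in the next hunk, _get_z_variable plausibly reduces to the following sketch (assuming k is the tf.keras backend; the sampled shape is an assumption, since the body is outside this diff):

    import tensorflow as tf
    from tensorflow.keras import backend as k

    def _get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
        # Draw z from a normal distribution and wrap it in a variable,
        # matching the tf.Variable return annotation above.
        z = k.random_normal((batch_size, zsize))  # shape assumed, not shown in the diff
        return k.variable(z)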
@@ -496,7 +496,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
     return k.variable(z)
 
 
-def normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
     """
     Normalizes a tensor from a 0-255 range to a 0-1 range and adds one dimension.
 
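For reference, the body of the renamed map function sits outside this diff; from its docstring it plausibly looks like the sketch below (the dtype cast and the position of the added dimension are assumptions):

    import tensorflow as tf
    from typing import Tuple

    def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # Scale 0-255 pixel values into the 0-1 range.
        feature = tf.cast(feature, tf.float32) / 255.0
        # Add the single (channel) dimension mentioned in the docstring.
        feature = tf.expand_dims(feature, axis=-1)
        return feature, label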