Fixed the frequency at which summaries are saved
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -314,7 +314,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
global_step_enc_dec: tf.Variable,
|
global_step_enc_dec: tf.Variable,
|
||||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||||
|
|
||||||
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY):
|
with summary_ops_v2.always_record_summaries():
|
||||||
epoch_var.assign(epoch)
|
epoch_var.assign(epoch)
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
# define loss variables
|
# define loss variables
|
||||||
@ -338,7 +338,8 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
inputs=x,
|
inputs=x,
|
||||||
targets_real=targets_real,
|
targets_real=targets_real,
|
||||||
targets_fake=targets_fake,
|
targets_fake=targets_fake,
|
||||||
global_step=global_step_xd,
|
global_step_xd=global_step_xd,
|
||||||
|
global_step=global_step,
|
||||||
z_generator=z_generator)
|
z_generator=z_generator)
|
||||||
xd_loss_avg(_xd_train_loss)
|
xd_loss_avg(_xd_train_loss)
|
||||||
|
|
||||||
@ -348,7 +349,8 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
x_discriminator=x_discriminator,
|
x_discriminator=x_discriminator,
|
||||||
optimizer=decoder_optimizer,
|
optimizer=decoder_optimizer,
|
||||||
targets=targets_real,
|
targets=targets_real,
|
||||||
global_step=global_step_decoder,
|
global_step_decoder=global_step_decoder,
|
||||||
|
global_step=global_step,
|
||||||
z_generator=z_generator)
|
z_generator=z_generator)
|
||||||
decoder_loss_avg(_decoder_train_loss)
|
decoder_loss_avg(_decoder_train_loss)
|
||||||
|
|
||||||
@ -360,7 +362,8 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
inputs=x,
|
inputs=x,
|
||||||
targets_real=targets_real,
|
targets_real=targets_real,
|
||||||
targets_fake=targets_fake,
|
targets_fake=targets_fake,
|
||||||
global_step=global_step_zd,
|
global_step_zd=global_step_zd,
|
||||||
|
global_step=global_step,
|
||||||
z_generator=z_generator)
|
z_generator=z_generator)
|
||||||
zd_loss_avg(_zd_train_loss)
|
zd_loss_avg(_zd_train_loss)
|
||||||
|
|
||||||
@ -372,15 +375,17 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
optimizer=enc_dec_optimizer,
|
optimizer=enc_dec_optimizer,
|
||||||
inputs=x,
|
inputs=x,
|
||||||
targets=targets_real,
|
targets=targets_real,
|
||||||
global_step=global_step_enc_dec)
|
global_step_enc_dec=global_step_enc_dec,
|
||||||
|
global_step=global_step)
|
||||||
enc_dec_loss_avg(reconstruction_loss)
|
enc_dec_loss_avg(reconstruction_loss)
|
||||||
encoder_loss_avg(encoder_loss)
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||||
summary_ops_v2.image(name='reconstruction',
|
grid = prepare_image(comparison.cpu(), nrow=64)
|
||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
summary_ops_v2.image(name='reconstruction',
|
||||||
step=global_step)
|
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||||
|
step=global_step)
|
||||||
global_step.assign_add(1)
|
global_step.assign_add(1)
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
@ -403,6 +408,7 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||||
|
global_step_xd: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Trains the x discriminator model for one step (one batch).
|
Trains the x discriminator model for one step (one batch).
|
||||||
@ -414,6 +420,7 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
:param targets_real: target tensor for real loss calculation
|
:param targets_real: target tensor for real loss calculation
|
||||||
:param targets_fake: target tensor for fake loss calculation
|
:param targets_fake: target tensor for fake loss calculation
|
||||||
:param global_step: the global step variable
|
:param global_step: the global step variable
|
||||||
|
:param global_step_xd: global step variable for xd
|
||||||
:param z_generator: callable function that returns a z variable
|
:param z_generator: callable function that returns a z variable
|
||||||
:return: the calculated loss
|
:return: the calculated loss
|
||||||
"""
|
"""
|
||||||
@ -428,15 +435,16 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
|
|
||||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||||
|
|
||||||
summary_ops_v2.scalar(name='x_discriminator_real_loss', tensor=xd_real_loss,
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='x_discriminator_real_loss', tensor=xd_real_loss,
|
||||||
summary_ops_v2.scalar(name='x_discriminator_fake_loss', tensor=xd_fake_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='x_discriminator_fake_loss', tensor=xd_fake_loss,
|
||||||
summary_ops_v2.scalar(name='x_discriminator_loss', tensor=_xd_train_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='x_discriminator_loss', tensor=_xd_train_loss,
|
||||||
|
step=global_step)
|
||||||
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step_xd)
|
||||||
|
|
||||||
return _xd_train_loss
|
return _xd_train_loss
|
||||||
|
|
||||||
@ -444,6 +452,7 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
targets: tf.Tensor, global_step: tf.Variable,
|
targets: tf.Tensor, global_step: tf.Variable,
|
||||||
|
global_step_decoder: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Trains the decoder model for one step (one batch).
|
Trains the decoder model for one step (one batch).
|
||||||
@ -453,6 +462,7 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
|||||||
:param optimizer: instance of chosen optimizer
|
:param optimizer: instance of chosen optimizer
|
||||||
:param targets: target tensor for loss calculation
|
:param targets: target tensor for loss calculation
|
||||||
:param global_step: the global step variable
|
:param global_step: the global step variable
|
||||||
|
:param global_step_decoder: global step variable for decoder
|
||||||
:param z_generator: callable function that returns a z variable
|
:param z_generator: callable function that returns a z variable
|
||||||
:return: the calculated loss
|
:return: the calculated loss
|
||||||
"""
|
"""
|
||||||
@ -463,11 +473,12 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
|||||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||||
_decoder_train_loss = binary_crossentropy(targets, xd_result)
|
_decoder_train_loss = binary_crossentropy(targets, xd_result)
|
||||||
|
|
||||||
summary_ops_v2.scalar(name='decoder_loss', tensor=_decoder_train_loss,
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='decoder_loss', tensor=_decoder_train_loss,
|
||||||
|
step=global_step)
|
||||||
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
|
optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step_decoder)
|
||||||
|
|
||||||
return _decoder_train_loss
|
return _decoder_train_loss
|
||||||
|
|
||||||
@ -476,6 +487,7 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
|
|||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||||
|
global_step_zd: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Trains the z discriminator one step (one batch).
|
Trains the z discriminator one step (one batch).
|
||||||
@ -487,6 +499,7 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
|
|||||||
:param targets_real: target tensor for real loss calculation
|
:param targets_real: target tensor for real loss calculation
|
||||||
:param targets_fake: target tensor for fake loss calculation
|
:param targets_fake: target tensor for fake loss calculation
|
||||||
:param global_step: the global step variable
|
:param global_step: the global step variable
|
||||||
|
:param global_step_zd: global step variable for zd
|
||||||
:param z_generator: callable function that returns a z variable
|
:param z_generator: callable function that returns a z variable
|
||||||
:return: the calculated loss
|
:return: the calculated loss
|
||||||
"""
|
"""
|
||||||
@ -502,22 +515,24 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
|
|||||||
|
|
||||||
_zd_train_loss = zd_real_loss + zd_fake_loss
|
_zd_train_loss = zd_real_loss + zd_fake_loss
|
||||||
|
|
||||||
summary_ops_v2.scalar(name='z_discriminator_real_loss', tensor=zd_real_loss,
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='z_discriminator_real_loss', tensor=zd_real_loss,
|
||||||
summary_ops_v2.scalar(name='z_discriminator_fake_loss', tensor=zd_fake_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='z_discriminator_fake_loss', tensor=zd_fake_loss,
|
||||||
summary_ops_v2.scalar(name='z_discriminator_loss', tensor=_zd_train_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='z_discriminator_loss', tensor=_zd_train_loss,
|
||||||
|
step=global_step)
|
||||||
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step_zd)
|
||||||
|
|
||||||
return _zd_train_loss
|
return _zd_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
||||||
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
||||||
targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
targets: tf.Tensor, global_step: tf.Variable,
|
||||||
|
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||||
"""
|
"""
|
||||||
Trains the encoder and decoder jointly for one step (one batch).
|
Trains the encoder and decoder jointly for one step (one batch).
|
||||||
|
|
||||||
@ -528,6 +543,7 @@ def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDi
|
|||||||
:param inputs: inputs from dataset
|
:param inputs: inputs from dataset
|
||||||
:param targets: target tensor for loss calculation
|
:param targets: target tensor for loss calculation
|
||||||
:param global_step: the global step variable
|
:param global_step: the global step variable
|
||||||
|
:param global_step_enc_dec: global step variable for enc_dec
|
||||||
:return: tuple of encoder loss, reconstruction loss, reconstructed input
|
:return: tuple of encoder loss, reconstruction loss, reconstructed input
|
||||||
"""
|
"""
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
@ -539,17 +555,18 @@ def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDi
|
|||||||
reconstruction_loss = binary_crossentropy(inputs, x_decoded)
|
reconstruction_loss = binary_crossentropy(inputs, x_decoded)
|
||||||
_enc_dec_train_loss = encoder_loss + reconstruction_loss
|
_enc_dec_train_loss = encoder_loss + reconstruction_loss
|
||||||
|
|
||||||
summary_ops_v2.scalar(name='encoder_loss', tensor=encoder_loss,
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='encoder_loss', tensor=encoder_loss,
|
||||||
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
|
||||||
summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss,
|
step=global_step)
|
||||||
step=global_step)
|
summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss,
|
||||||
|
step=global_step)
|
||||||
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
||||||
encoder.trainable_variables + decoder.trainable_variables)
|
encoder.trainable_variables + decoder.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(enc_dec_grads,
|
optimizer.apply_gradients(zip(enc_dec_grads,
|
||||||
encoder.trainable_variables + decoder.trainable_variables),
|
encoder.trainable_variables + decoder.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step_enc_dec)
|
||||||
|
|
||||||
return encoder_loss, reconstruction_loss, x_decoded
|
return encoder_loss, reconstruction_loss, x_decoded
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user