Removed sample generation
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -59,8 +59,8 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
epoch as well as after finishing training (or stopping early). When starting
|
epoch as well as after finishing training (or stopping early). When starting
|
||||||
this function with the same ``iteration`` then the training will try to
|
this function with the same ``iteration`` then the training will try to
|
||||||
continue where it ended last time by restoring a saved checkpoint.
|
continue where it ended last time by restoring a saved checkpoint.
|
||||||
The loss values are provided as scalar summaries. Reconstruction and sample
|
The loss values are provided as scalar summaries. Reconstruction images are
|
||||||
images are provided as summary images.
|
provided as summary images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: train dataset
|
dataset: train dataset
|
||||||
@ -73,9 +73,6 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
verbose: if True prints train progress info to console (default: True)
|
verbose: if True prints train progress info to console (default: True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# non-preserved tensors
|
|
||||||
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
|
|
||||||
|
|
||||||
# checkpointed tensors and variables
|
# checkpointed tensors and variables
|
||||||
checkpointables = {
|
checkpointables = {
|
||||||
'learning_rate_var': K.variable(lr),
|
'learning_rate_var': K.variable(lr),
|
||||||
@ -126,16 +123,6 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
|
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
|
||||||
))
|
))
|
||||||
|
|
||||||
# save sample image summary
|
|
||||||
def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
|
|
||||||
resultsample = decoder(sample).cpu()
|
|
||||||
grid = util.prepare_image(resultsample)
|
|
||||||
summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
|
|
||||||
max_images=1, step=global_step)
|
|
||||||
|
|
||||||
with summary_ops_v2.always_record_summaries():
|
|
||||||
_save_sample(**checkpointables)
|
|
||||||
|
|
||||||
# save weights at end of epoch
|
# save weights at end of epoch
|
||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user