Removed sample generation
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -59,8 +59,8 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
epoch as well as after finishing training (or stopping early). When starting
|
||||
this function with the same ``iteration`` then the training will try to
|
||||
continue where it ended last time by restoring a saved checkpoint.
|
||||
The loss values are provided as scalar summaries. Reconstruction and sample
|
||||
images are provided as summary images.
|
||||
The loss values are provided as scalar summaries. Reconstruction images are
|
||||
provided as summary images.
|
||||
|
||||
Args:
|
||||
dataset: train dataset
|
||||
@ -73,9 +73,6 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
verbose: if True prints train progress info to console (default: True)
|
||||
"""
|
||||
|
||||
# non-preserved tensors
|
||||
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
|
||||
|
||||
# checkpointed tensors and variables
|
||||
checkpointables = {
|
||||
'learning_rate_var': K.variable(lr),
|
||||
@ -126,16 +123,6 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
|
||||
))
|
||||
|
||||
# save sample image summary
|
||||
def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
|
||||
resultsample = decoder(sample).cpu()
|
||||
grid = util.prepare_image(resultsample)
|
||||
summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
|
||||
max_images=1, step=global_step)
|
||||
|
||||
with summary_ops_v2.always_record_summaries():
|
||||
_save_sample(**checkpointables)
|
||||
|
||||
# save weights at end of epoch
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user