Removed sample generation

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-04 17:28:27 +02:00
parent 17537b944b
commit 05fa489905

View File

@ -59,8 +59,8 @@ def train_simple(dataset: tf.data.Dataset,
epoch as well as after finishing training (or stopping early). When starting
this function with the same ``iteration`` then the training will try to
continue where it ended last time by restoring a saved checkpoint.
The loss values are provided as scalar summaries. Reconstruction and sample
images are provided as summary images.
The loss values are provided as scalar summaries. Reconstruction images are
provided as summary images.
Args:
dataset: train dataset
@ -73,9 +73,6 @@ def train_simple(dataset: tf.data.Dataset,
verbose: if True prints train progress info to console (default: True)
"""
# non-preserved tensors
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
# checkpointed tensors and variables
checkpointables = {
'learning_rate_var': K.variable(lr),
@ -126,16 +123,6 @@ def train_simple(dataset: tf.data.Dataset,
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
))
# save sample image summary
def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
resultsample = decoder(sample).cpu()
grid = util.prepare_image(resultsample)
summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
max_images=1, step=global_step)
with summary_ops_v2.always_record_summaries():
_save_sample(**checkpointables)
# save weights at end of epoch
checkpoint.save(checkpoint_prefix)