Fixed sample image summary

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 18:58:19 +01:00
parent 3b9742a1b4
commit a2aa794d47

View File

@ -194,6 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
beta1=0.5, beta2=0.999),
# global step counter
'epoch_var': k.variable(-1, dtype=tf.int64),
'global_step': k.variable(0, dtype=tf.int64),
'global_step_decoder': k.variable(0, dtype=tf.int64),
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
'global_step_xd': k.variable(0, dtype=tf.int64),
@ -235,12 +236,14 @@ def train(dataset: tf.data.Dataset, iteration: int,
))
# save sample image summary
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
def _save_sample(decoder: Decoder, global_step: tf.Variable, **kwargs) -> None:
resultsample = decoder(sample).cpu()
grid = prepare_image(resultsample)
summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0),
max_images=1, step=global_step_decoder)
_save_sample(**checkpointables)
max_images=1, step=global_step)
with summary_ops_v2.always_record_summaries():
_save_sample(**checkpointables)
# save weights at end of epoch
checkpoint.save(checkpoint_prefix)
@ -305,13 +308,14 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
x_discriminator_optimizer: tf.train.Optimizer,
z_discriminator_optimizer: tf.train.Optimizer,
enc_dec_optimizer: tf.train.Optimizer,
global_step: tf.Variable,
global_step_xd: tf.Variable, global_step_zd: tf.Variable,
global_step_decoder: tf.Variable,
global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]:
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
global_step=global_step_decoder):
global_step=global_step):
epoch_var.assign(epoch)
epoch_start_time = time.time()
# define loss variables
@ -382,6 +386,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
step=global_step_decoder)
batch_iteration.assign_add(1)
global_step.assign_add(1)
epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time