Fixed sample image summary
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -194,6 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
beta1=0.5, beta2=0.999),
|
||||
# global step counter
|
||||
'epoch_var': k.variable(-1, dtype=tf.int64),
|
||||
'global_step': k.variable(0, dtype=tf.int64),
|
||||
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
||||
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
||||
'global_step_xd': k.variable(0, dtype=tf.int64),
|
||||
@ -235,12 +236,14 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
))
|
||||
|
||||
# save sample image summary
|
||||
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
|
||||
def _save_sample(decoder: Decoder, global_step: tf.Variable, **kwargs) -> None:
|
||||
resultsample = decoder(sample).cpu()
|
||||
grid = prepare_image(resultsample)
|
||||
summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0),
|
||||
max_images=1, step=global_step_decoder)
|
||||
_save_sample(**checkpointables)
|
||||
max_images=1, step=global_step)
|
||||
|
||||
with summary_ops_v2.always_record_summaries():
|
||||
_save_sample(**checkpointables)
|
||||
|
||||
# save weights at end of epoch
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
@ -305,13 +308,14 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
x_discriminator_optimizer: tf.train.Optimizer,
|
||||
z_discriminator_optimizer: tf.train.Optimizer,
|
||||
enc_dec_optimizer: tf.train.Optimizer,
|
||||
global_step: tf.Variable,
|
||||
global_step_xd: tf.Variable, global_step_zd: tf.Variable,
|
||||
global_step_decoder: tf.Variable,
|
||||
global_step_enc_dec: tf.Variable,
|
||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||
|
||||
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
|
||||
global_step=global_step_decoder):
|
||||
global_step=global_step):
|
||||
epoch_var.assign(epoch)
|
||||
epoch_start_time = time.time()
|
||||
# define loss variables
|
||||
@ -382,6 +386,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
step=global_step_decoder)
|
||||
|
||||
batch_iteration.assign_add(1)
|
||||
global_step.assign_add(1)
|
||||
|
||||
epoch_end_time = time.time()
|
||||
per_epoch_time = epoch_end_time - epoch_start_time
|
||||
|
||||
Reference in New Issue
Block a user