Added epoch counter to restored variables

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 18:12:02 +01:00
parent 1d4f56e40d
commit d9a35893af

View File

@ -190,6 +190,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], 'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999), beta1=0.5, beta2=0.999),
# global step counter # global step counter
'epoch_var': k.variable(-1, dtype=tf.int64),
'global_step_decoder': k.variable(0, dtype=tf.int64), 'global_step_decoder': k.variable(0, dtype=tf.int64),
'global_step_enc_dec': k.variable(0, dtype=tf.int64), 'global_step_enc_dec': k.variable(0, dtype=tf.int64),
'global_step_xd': k.variable(0, dtype=tf.int64), 'global_step_xd': k.variable(0, dtype=tf.int64),
@ -204,15 +205,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
checkpoint = tf.train.Checkpoint(**checkpointables) checkpoint = tf.train.Checkpoint(**checkpointables)
checkpoint.restore(latest_checkpoint) checkpoint.restore(latest_checkpoint)
for epoch in range(train_epoch): def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int:
outputs = _train_one_epoch(epoch, dataset, targets_real=y_real, return int(epoch_var)
targets_fake=y_fake, z_generator= z_generator,
last_epoch = _get_last_epoch(**checkpointables)
previous_epochs = 0
if last_epoch != -1:
previous_epochs = last_epoch + 1
for epoch in range(train_epoch - previous_epochs):
_epoch = epoch + previous_epochs
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
targets_fake=y_fake, z_generator=z_generator,
verbose=verbose, verbose=verbose,
**checkpointables) **checkpointables)
if verbose: if verbose:
print(( print((
f"[{epoch + 1:d}/{train_epoch:d}] - " f"[{_epoch + 1:d}/{train_epoch:d}] - "
f"train time: {outputs['per_epoch_time']:.2f}, " f"train time: {outputs['per_epoch_time']:.2f}, "
f"Decoder loss: {outputs['decoder_loss']:.3f}, " f"Decoder loss: {outputs['decoder_loss']:.3f}, "
f"X Discriminator loss: {outputs['xd_loss']:.3f}, " f"X Discriminator loss: {outputs['xd_loss']:.3f}, "
@ -222,7 +232,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
)) ))
# save sample image summary # save sample image summary
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None: def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
resultsample = decoder(sample).cpu() resultsample = decoder(sample).cpu()
grid = prepare_image(resultsample) grid = prepare_image(resultsample)
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
@ -293,8 +303,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
enc_dec_optimizer: tf.train.Optimizer, enc_dec_optimizer: tf.train.Optimizer,
global_step_xd: tf.Variable, global_step_zd: tf.Variable, global_step_xd: tf.Variable, global_step_zd: tf.Variable,
global_step_decoder: tf.Variable, global_step_decoder: tf.Variable,
global_step_enc_dec: tf.Variable) -> Dict[str, float]: global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]:
epoch_var.assign(epoch)
epoch_start_time = time.time() epoch_start_time = time.time()
# define loss variables # define loss variables
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)