Added epoch counter to restored variables
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -190,6 +190,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||
beta1=0.5, beta2=0.999),
|
||||
# global step counter
|
||||
'epoch_var': k.variable(-1, dtype=tf.int64),
|
||||
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
||||
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
||||
'global_step_xd': k.variable(0, dtype=tf.int64),
|
||||
@ -204,15 +205,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
checkpoint = tf.train.Checkpoint(**checkpointables)
|
||||
checkpoint.restore(latest_checkpoint)
|
||||
|
||||
for epoch in range(train_epoch):
|
||||
outputs = _train_one_epoch(epoch, dataset, targets_real=y_real,
|
||||
def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int:
|
||||
return int(epoch_var)
|
||||
|
||||
last_epoch = _get_last_epoch(**checkpointables)
|
||||
previous_epochs = 0
|
||||
if last_epoch != -1:
|
||||
previous_epochs = last_epoch + 1
|
||||
|
||||
for epoch in range(train_epoch - previous_epochs):
|
||||
_epoch = epoch + previous_epochs
|
||||
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
|
||||
targets_fake=y_fake, z_generator=z_generator,
|
||||
verbose=verbose,
|
||||
**checkpointables)
|
||||
|
||||
if verbose:
|
||||
print((
|
||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||
f"[{_epoch + 1:d}/{train_epoch:d}] - "
|
||||
f"train time: {outputs['per_epoch_time']:.2f}, "
|
||||
f"Decoder loss: {outputs['decoder_loss']:.3f}, "
|
||||
f"X Discriminator loss: {outputs['xd_loss']:.3f}, "
|
||||
@ -222,7 +232,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
))
|
||||
|
||||
# save sample image summary
|
||||
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None:
|
||||
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
|
||||
resultsample = decoder(sample).cpu()
|
||||
grid = prepare_image(resultsample)
|
||||
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
||||
@ -293,8 +303,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
enc_dec_optimizer: tf.train.Optimizer,
|
||||
global_step_xd: tf.Variable, global_step_zd: tf.Variable,
|
||||
global_step_decoder: tf.Variable,
|
||||
global_step_enc_dec: tf.Variable) -> Dict[str, float]:
|
||||
global_step_enc_dec: tf.Variable,
|
||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||
|
||||
epoch_var.assign(epoch)
|
||||
epoch_start_time = time.time()
|
||||
# define loss variables
|
||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||
|
||||
Reference in New Issue
Block a user