Added epoch counter to restored variables
Signed-off-by: Jim Martens <github@2martens.de>
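In outline, the change keeps the current epoch in a checkpointed tf.Variable so that a restored run resumes where it stopped instead of restarting at epoch 0. Below is a minimal, self-contained sketch of that pattern, not the project's actual code: it assumes TF 1.x eager mode, and the checkpoint directory and epoch count are illustrative.

import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x; eager is on by default in TF 2

# -1 is a sentinel meaning "no epoch has run yet"
epoch_var = tf.Variable(-1, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(epoch_var=epoch_var)

latest = tf.train.latest_checkpoint('/tmp/ckpts')  # illustrative path
if latest is not None:
    checkpoint.restore(latest)  # overwrites epoch_var with the saved value

train_epoch = 10                      # illustrative total epoch count
previous_epochs = int(epoch_var) + 1  # 0 on a fresh run

for epoch in range(previous_epochs, train_epoch):
    # ... train one epoch ...
    epoch_var.assign(epoch)           # record progress
    checkpoint.save('/tmp/ckpts/ckpt')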
@@ -190,6 +190,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
         'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                    beta1=0.5, beta2=0.999),
         # global step counter
+        'epoch_var': k.variable(-1, dtype=tf.int64),
         'global_step_decoder': k.variable(0, dtype=tf.int64),
         'global_step_enc_dec': k.variable(0, dtype=tf.int64),
         'global_step_xd': k.variable(0, dtype=tf.int64),
@@ -204,15 +205,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
     checkpoint = tf.train.Checkpoint(**checkpointables)
     checkpoint.restore(latest_checkpoint)
 
-    for epoch in range(train_epoch):
-        outputs = _train_one_epoch(epoch, dataset, targets_real=y_real,
+    def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int:
+        return int(epoch_var)
+
+    last_epoch = _get_last_epoch(**checkpointables)
+    previous_epochs = 0
+    if last_epoch != -1:
+        previous_epochs = last_epoch + 1
+
+    for epoch in range(train_epoch - previous_epochs):
+        _epoch = epoch + previous_epochs
+        outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
                                    targets_fake=y_fake, z_generator=z_generator,
                                    verbose=verbose,
                                    **checkpointables)
 
         if verbose:
             print((
-                f"[{epoch + 1:d}/{train_epoch:d}] - "
+                f"[{_epoch + 1:d}/{train_epoch:d}] - "
                 f"train time: {outputs['per_epoch_time']:.2f}, "
                 f"Decoder loss: {outputs['decoder_loss']:.3f}, "
                 f"X Discriminator loss: {outputs['xd_loss']:.3f}, "
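A note on the resume arithmetic in the hunk above: epoch_var stores the index of the last epoch that ran, so previous_epochs = last_epoch + 1 counts the epochs already completed; the loop then iterates only over the remainder, and _epoch recovers the absolute epoch index. With illustrative values:

last_epoch = 4                      # restored from the checkpoint
previous_epochs = last_epoch + 1    # 5 epochs already done
train_epoch = 10
absolute = [e + previous_epochs for e in range(train_epoch - previous_epochs)]
assert absolute == [5, 6, 7, 8, 9]  # training resumes at epoch 5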
@@ -222,7 +232,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
             ))
 
     # save sample image summary
-    def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None:
+    def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
         resultsample = decoder(sample).cpu()
         grid = prepare_image(resultsample)
         summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
@@ -293,8 +303,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
                      enc_dec_optimizer: tf.train.Optimizer,
                      global_step_xd: tf.Variable, global_step_zd: tf.Variable,
                      global_step_decoder: tf.Variable,
-                     global_step_enc_dec: tf.Variable) -> Dict[str, float]:
+                     global_step_enc_dec: tf.Variable,
+                     epoch_var: tf.Variable) -> Dict[str, float]:
 
+    epoch_var.assign(epoch)
     epoch_start_time = time.time()
     # define loss variables
     encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
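The helpers touched here (_get_last_epoch, _save_sample, _train_one_epoch) all accept **kwargs so the whole checkpointables dict can be splatted into them, with each helper binding only the entries it needs. A small illustration of that pattern, with made-up variable values:

import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x; eager is on by default in TF 2

def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int:
    return int(epoch_var)  # unused checkpointables are absorbed by **kwargs

checkpointables = {'epoch_var': tf.Variable(3, dtype=tf.int64),
                   'global_step_xd': tf.Variable(0, dtype=tf.int64)}
assert _get_last_epoch(**checkpointables) == 3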