@ -447,7 +447,7 @@ def train(dataset: tf.data.Dataset,
|
|||||||
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
|
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
|
||||||
# z generator function
|
# z generator function
|
||||||
z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
|
z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||||
|
|
||||||
# non-preserved python variables
|
# non-preserved python variables
|
||||||
encoder_lowest_loss = math.inf
|
encoder_lowest_loss = math.inf
|
||||||
decoder_lowest_loss = math.inf
|
decoder_lowest_loss = math.inf
|
||||||
@ -463,24 +463,26 @@ def train(dataset: tf.data.Dataset,
|
|||||||
}
|
}
|
||||||
checkpointables.update({
|
checkpointables.update({
|
||||||
# get models
|
# get models
|
||||||
'encoder': model.Encoder(zsize),
|
'encoder': model.Encoder(zsize),
|
||||||
'decoder': model.Decoder(channels),
|
'decoder': model.Decoder(channels),
|
||||||
'z_discriminator': model.ZDiscriminator(),
|
'z_discriminator': model.ZDiscriminator(),
|
||||||
'x_discriminator': model.XDiscriminator(),
|
'x_discriminator': model.XDiscriminator(),
|
||||||
# define optimizers
|
# define optimizers
|
||||||
'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
|
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
|
beta1=0.5, beta2=0.999),
|
||||||
'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
'x_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
'x_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
# global step counter
|
# global step counter
|
||||||
'epoch_var': K.variable(-1, dtype=tf.int64),
|
'epoch_var': K.variable(-1, dtype=tf.int64),
|
||||||
'global_step': tf.train.get_or_create_global_step(),
|
'global_step': tf.train.get_or_create_global_step(),
|
||||||
'global_step_decoder': K.variable(0, dtype=tf.int64),
|
'global_step_decoder': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_enc_dec': K.variable(0, dtype=tf.int64),
|
'global_step_enc_dec': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_xd': K.variable(0, dtype=tf.int64),
|
'global_step_xd': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_zd': K.variable(0, dtype=tf.int64),
|
'global_step_zd': K.variable(0, dtype=tf.int64),
|
||||||
})
|
})
|
||||||
|
|
||||||
# checkpoint
|
# checkpoint
|
||||||
@ -533,13 +535,13 @@ def train(dataset: tf.data.Dataset,
|
|||||||
|
|
||||||
# save weights at end of epoch
|
# save weights at end of epoch
|
||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
# check for improvements in error reduction - otherwise early stopping
|
# check for improvements in error reduction - otherwise early stopping
|
||||||
if early_stopping:
|
if early_stopping:
|
||||||
strike = False
|
strike = False
|
||||||
total_strike = False
|
total_strike = False
|
||||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||||
outputs['xd_loss'] + outputs['zd_loss']
|
outputs['xd_loss'] + outputs['zd_loss']
|
||||||
if total_loss < total_lowest_loss:
|
if total_loss < total_lowest_loss:
|
||||||
total_lowest_loss = total_loss
|
total_lowest_loss = total_loss
|
||||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||||
@ -571,16 +573,16 @@ def train(dataset: tf.data.Dataset,
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
grace_period = GRACE
|
grace_period = GRACE
|
||||||
|
|
||||||
if grace_period == 0:
|
if grace_period == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if grace_period > 0:
|
if grace_period > 0:
|
||||||
print("Training finish!... save model weights")
|
print("Training finish!... save model weights")
|
||||||
if grace_period == 0:
|
if grace_period == 0:
|
||||||
print("Training stopped early!... save model weights")
|
print("Training stopped early!... save model weights")
|
||||||
|
|
||||||
# save trained models
|
# save trained models
|
||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
@ -606,7 +608,6 @@ def _train_one_epoch(epoch: int,
|
|||||||
global_step_decoder: tf.Variable,
|
global_step_decoder: tf.Variable,
|
||||||
global_step_enc_dec: tf.Variable,
|
global_step_enc_dec: tf.Variable,
|
||||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||||
|
|
||||||
with summary_ops_v2.always_record_summaries():
|
with summary_ops_v2.always_record_summaries():
|
||||||
epoch_var.assign(epoch)
|
epoch_var.assign(epoch)
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
@ -688,11 +689,11 @@ def _train_one_epoch(epoch: int,
|
|||||||
|
|
||||||
# final losses of epoch
|
# final losses of epoch
|
||||||
outputs = {
|
outputs = {
|
||||||
'decoder_loss': decoder_loss_avg.result(False),
|
'decoder_loss': decoder_loss_avg.result(False),
|
||||||
'encoder_loss': encoder_loss_avg.result(False),
|
'encoder_loss': encoder_loss_avg.result(False),
|
||||||
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
||||||
'xd_loss': xd_loss_avg.result(False),
|
'xd_loss': xd_loss_avg.result(False),
|
||||||
'zd_loss': zd_loss_avg.result(False),
|
'zd_loss': zd_loss_avg.result(False),
|
||||||
'per_epoch_time': per_epoch_time,
|
'per_epoch_time': per_epoch_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user