Fixed coding style

Signed-off-by: Jim Martens <github@2martens.de>
commit 6835c284af
parent 9a08ea3bb9
Date: 2019-03-25 14:32:08 +01:00


@@ -463,24 +463,26 @@ def train(dataset: tf.data.Dataset,
     }
     checkpointables.update({
         # get models
         'encoder': model.Encoder(zsize),
         'decoder': model.Decoder(channels),
         'z_discriminator': model.ZDiscriminator(),
         'x_discriminator': model.XDiscriminator(),
         # define optimizers
-        'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
-        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
+        'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+                                                    beta1=0.5, beta2=0.999),
+        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+                                                    beta1=0.5, beta2=0.999),
         'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),
         'x_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),
         # global step counter
         'epoch_var': K.variable(-1, dtype=tf.int64),
         'global_step': tf.train.get_or_create_global_step(),
         'global_step_decoder': K.variable(0, dtype=tf.int64),
         'global_step_enc_dec': K.variable(0, dtype=tf.int64),
         'global_step_xd': K.variable(0, dtype=tf.int64),
         'global_step_zd': K.variable(0, dtype=tf.int64),
     })
     # checkpoint
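For context: checkpointables maps names to trackable objects (models, optimizers, step counters) so one checkpoint captures all of them. A minimal sketch of that pattern, assuming TF 1.x eager execution; the dict keys and save path below are placeholders, not the full mapping from this file:

import tensorflow as tf

# Placeholder mapping in the same shape as the diff's checkpointables;
# the real dict also holds the models and optimizers shown above.
checkpointables = {
    'learning_rate_var': tf.keras.backend.variable(0.002, dtype=tf.float32),
    'global_step': tf.train.get_or_create_global_step(),
}

# tf.train.Checkpoint takes the mapping as keyword arguments, so each
# entry is saved and restored under its dict key.
checkpoint = tf.train.Checkpoint(**checkpointables)
save_path = checkpoint.save(file_prefix='/tmp/aae/ckpt')  # hypothetical path
checkpoint.restore(save_path)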
@@ -539,7 +541,7 @@ def train(dataset: tf.data.Dataset,
         strike = False
         total_strike = False
         total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
-            outputs['xd_loss'] + outputs['zd_loss']
+                     outputs['xd_loss'] + outputs['zd_loss']
         if total_loss < total_lowest_loss:
             total_lowest_loss = total_loss
         elif total_loss > TOTAL_LOSS_GRACE_CAP:
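The surrounding logic reads as a divergence guard: the five losses are summed, tracked against the lowest total seen so far, and totals above TOTAL_LOSS_GRACE_CAP count as strikes. A minimal sketch of that heuristic, in which the cap value, strike limit, and epoch_outputs() iterator are all illustrative assumptions:

TOTAL_LOSS_GRACE_CAP = 500.0   # assumed value; the real constant lives elsewhere in this file
MAX_STRIKES = 3                # assumed limit, not taken from the repo

total_lowest_loss = float('inf')
strikes = 0
for outputs in epoch_outputs():   # hypothetical iterator over per-epoch loss dicts
    total_loss = (outputs['encoder_loss'] + outputs['decoder_loss'] +
                  outputs['enc_dec_loss'] + outputs['xd_loss'] + outputs['zd_loss'])
    if total_loss < total_lowest_loss:
        total_lowest_loss = total_loss   # new best total, clear the strikes
        strikes = 0
    elif total_loss > TOTAL_LOSS_GRACE_CAP:
        strikes += 1                     # diverging past the grace cap
        if strikes >= MAX_STRIKES:
            break                        # assumed stopping rule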
@@ -606,7 +608,6 @@ def _train_one_epoch(epoch: int,
                      global_step_decoder: tf.Variable,
                      global_step_enc_dec: tf.Variable,
                      epoch_var: tf.Variable) -> Dict[str, float]:
-
     with summary_ops_v2.always_record_summaries():
         epoch_var.assign(epoch)
         epoch_start_time = time.time()
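always_record_summaries forces every summary op in its scope to be written, regardless of the global record condition. A minimal sketch of how it combines with an eager summary writer in TF 1.x; the log directory and scalar name are illustrative:

import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2

tf.enable_eager_execution()
writer = tf.contrib.summary.create_file_writer('/tmp/aae/logs')  # hypothetical path
with writer.as_default(), summary_ops_v2.always_record_summaries():
    # every tf.contrib.summary op inside this scope is recorded
    tf.contrib.summary.scalar('epoch_loss', 0.5,
                              step=tf.train.get_or_create_global_step())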
@@ -688,11 +689,11 @@ def _train_one_epoch(epoch: int,
         # final losses of epoch
         outputs = {
             'decoder_loss': decoder_loss_avg.result(False),
             'encoder_loss': encoder_loss_avg.result(False),
             'enc_dec_loss': enc_dec_loss_avg.result(False),
             'xd_loss': xd_loss_avg.result(False),
             'zd_loss': zd_loss_avg.result(False),
             'per_epoch_time': per_epoch_time,
         }
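The *_avg objects are running means accumulated over the epoch's batches, and result(False) appears to pass write_summary=False so that reading the mean does not itself emit a summary. A minimal sketch of that accumulation with the tf.contrib.eager metrics of the era; the per-batch losses are placeholders:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss')
for batch_loss in [0.9, 0.7, 0.6]:   # hypothetical per-batch losses
    decoder_loss_avg(batch_loss)      # calling the metric adds a sample
outputs = {'decoder_loss': decoder_loss_avg.result(False)}  # read without writing a summary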