diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 06bf50e..fab56e2 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -447,7 +447,7 @@ def train(dataset: tf.data.Dataset,
     sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
     # z generator function
     z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
-    
+
     # non-preserved python variables
     encoder_lowest_loss = math.inf
     decoder_lowest_loss = math.inf
@@ -463,24 +463,26 @@ def train(dataset: tf.data.Dataset,
     }
     checkpointables.update({
         # get models
-        'encoder': model.Encoder(zsize),
-        'decoder': model.Decoder(channels),
-        'z_discriminator': model.ZDiscriminator(),
-        'x_discriminator': model.XDiscriminator(),
+        'encoder':                   model.Encoder(zsize),
+        'decoder':                   model.Decoder(channels),
+        'z_discriminator':           model.ZDiscriminator(),
+        'x_discriminator':           model.XDiscriminator(),
         # define optimizers
-        'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
-        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
+        'decoder_optimizer':         tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+                                                            beta1=0.5, beta2=0.999),
+        'enc_dec_optimizer':         tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+                                                            beta1=0.5, beta2=0.999),
         'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),
         'x_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                             beta1=0.5, beta2=0.999),
         # global step counter
-        'epoch_var': K.variable(-1, dtype=tf.int64),
-        'global_step': tf.train.get_or_create_global_step(),
-        'global_step_decoder': K.variable(0, dtype=tf.int64),
-        'global_step_enc_dec': K.variable(0, dtype=tf.int64),
-        'global_step_xd': K.variable(0, dtype=tf.int64),
-        'global_step_zd': K.variable(0, dtype=tf.int64),
+        'epoch_var':                 K.variable(-1, dtype=tf.int64),
+        'global_step':               tf.train.get_or_create_global_step(),
+        'global_step_decoder':       K.variable(0, dtype=tf.int64),
+        'global_step_enc_dec':       K.variable(0, dtype=tf.int64),
+        'global_step_xd':            K.variable(0, dtype=tf.int64),
+        'global_step_zd':            K.variable(0, dtype=tf.int64),
     })
 
     # checkpoint
@@ -533,13 +535,13 @@ def train(dataset: tf.data.Dataset,
 
         # save weights at end of epoch
         checkpoint.save(checkpoint_prefix)
-        
+
        # check for improvements in error reduction - otherwise early stopping
         if early_stopping:
             strike = False
             total_strike = False
             total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
-                outputs['xd_loss'] + outputs['zd_loss']
+                         outputs['xd_loss'] + outputs['zd_loss']
             if total_loss < total_lowest_loss:
                 total_lowest_loss = total_loss
             elif total_loss > TOTAL_LOSS_GRACE_CAP:
@@ -571,16 +573,16 @@ def train(dataset: tf.data.Dataset,
                     pass
                 else:
                     grace_period = GRACE
-            
+
             if grace_period == 0:
                 break
-        
+
         if verbose:
             if grace_period > 0:
                 print("Training finish!... save model weights")
             if grace_period == 0:
                 print("Training stopped early!... save model weights")
-    
+
     # save trained models
     checkpoint.save(checkpoint_prefix)
 
@@ -606,7 +608,6 @@ def _train_one_epoch(epoch: int,
                      global_step_decoder: tf.Variable,
                      global_step_enc_dec: tf.Variable,
                      epoch_var: tf.Variable) -> Dict[str, float]:
-    
     with summary_ops_v2.always_record_summaries():
         epoch_var.assign(epoch)
         epoch_start_time = time.time()
@@ -688,11 +689,11 @@
 
         # final losses of epoch
         outputs = {
-            'decoder_loss': decoder_loss_avg.result(False),
-            'encoder_loss': encoder_loss_avg.result(False),
-            'enc_dec_loss': enc_dec_loss_avg.result(False),
-            'xd_loss': xd_loss_avg.result(False),
-            'zd_loss': zd_loss_avg.result(False),
+            'decoder_loss':   decoder_loss_avg.result(False),
+            'encoder_loss':   encoder_loss_avg.result(False),
+            'enc_dec_loss':   enc_dec_loss_avg.result(False),
+            'xd_loss':        xd_loss_avg.result(False),
+            'zd_loss':        zd_loss_avg.result(False),
             'per_epoch_time': per_epoch_time,
         }