diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 7d3de5f..8dfbea1 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -84,8 +84,7 @@ def train_simple(dataset: tf.data.Dataset, 'encoder': model.Encoder(zsize), 'decoder': model.Decoder(channels, zsize), # define optimizers - 'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], - beta1=0.5, beta2=0.999), + 'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var']), # global step counter 'epoch_var': K.variable(-1, dtype=tf.int64), 'global_step': tf.train.get_or_create_global_step(), @@ -201,7 +200,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, :param encoder: instance of encoder model :param decoder: instance of decoder model :param optimizer: instance of chosen optimizer - :param inputs: inputs from dataset + :param inputs: inputs from data set :param global_step: the global step variable :param global_step_enc_dec: global step variable for enc_dec :return: tuple of reconstruction loss, reconstructed input