From a7f9d586f4f724bb653e4fe5550f5bbc7436eac4 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Wed, 10 Apr 2019 15:18:39 +0200 Subject: [PATCH] Fixed loop through data set Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index f6bdc2f..eeeec6a 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -160,8 +160,8 @@ def _train_one_epoch_simple(epoch: int, step=global_step) if verbose: print("learning rate change!") - - for x, _ in dataset: + + for x in dataset: reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder, decoder=decoder, optimizer=enc_dec_optimizer, @@ -235,12 +235,12 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, if __name__ == "__main__": from twomartens.masterthesis.aae.data import prepare_training_data tf.enable_eager_execution() - inlier_classes = [3] - iteration = 1 + inlier_classes = [8] + iteration = 2 train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes, total_classes=10) train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(): - train(dataset=train_dataset, iteration=iteration, - weights_prefix='weights/' + str(inlier_classes[0]) + '/') + train_simple(dataset=train_dataset, iteration=iteration, + weights_prefix='weights/' + str(inlier_classes[0]) + '/')