Fixed loop through data set

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-10 15:18:39 +02:00
parent da6e348edc
commit a7f9d586f4

View File

@ -161,7 +161,7 @@ def _train_one_epoch_simple(epoch: int,
if verbose:
print("learning rate change!")
for x, _ in dataset:
for x in dataset:
reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder,
decoder=decoder,
optimizer=enc_dec_optimizer,
@ -235,12 +235,12 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
if __name__ == "__main__":
from twomartens.masterthesis.aae.data import prepare_training_data
tf.enable_eager_execution()
inlier_classes = [3]
iteration = 1
inlier_classes = [8]
iteration = 2
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes,
total_classes=10)
train_summary_writer = summary_ops_v2.create_file_writer(
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default():
train(dataset=train_dataset, iteration=iteration,
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
train_simple(dataset=train_dataset, iteration=iteration,
weights_prefix='weights/' + str(inlier_classes[0]) + '/')