Fixed loop through data set

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-10 15:18:39 +02:00
parent da6e348edc
commit a7f9d586f4

View File

@ -160,8 +160,8 @@ def _train_one_epoch_simple(epoch: int,
step=global_step) step=global_step)
if verbose: if verbose:
print("learning rate change!") print("learning rate change!")
for x, _ in dataset: for x in dataset:
reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder, reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder,
decoder=decoder, decoder=decoder,
optimizer=enc_dec_optimizer, optimizer=enc_dec_optimizer,
@ -235,12 +235,12 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
if __name__ == "__main__": if __name__ == "__main__":
from twomartens.masterthesis.aae.data import prepare_training_data from twomartens.masterthesis.aae.data import prepare_training_data
tf.enable_eager_execution() tf.enable_eager_execution()
inlier_classes = [3] inlier_classes = [8]
iteration = 1 iteration = 2
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes, train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes,
total_classes=10) total_classes=10)
train_summary_writer = summary_ops_v2.create_file_writer( train_summary_writer = summary_ops_v2.create_file_writer(
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default(): with train_summary_writer.as_default():
train(dataset=train_dataset, iteration=iteration, train_simple(dataset=train_dataset, iteration=iteration,
weights_prefix='weights/' + str(inlier_classes[0]) + '/') weights_prefix='weights/' + str(inlier_classes[0]) + '/')