Fixed loop through data set
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -160,8 +160,8 @@ def _train_one_epoch_simple(epoch: int,
|
||||
step=global_step)
|
||||
if verbose:
|
||||
print("learning rate change!")
|
||||
|
||||
for x, _ in dataset:
|
||||
|
||||
for x in dataset:
|
||||
reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder,
|
||||
decoder=decoder,
|
||||
optimizer=enc_dec_optimizer,
|
||||
@ -235,12 +235,12 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
|
||||
if __name__ == "__main__":
|
||||
from twomartens.masterthesis.aae.data import prepare_training_data
|
||||
tf.enable_eager_execution()
|
||||
inlier_classes = [3]
|
||||
iteration = 1
|
||||
inlier_classes = [8]
|
||||
iteration = 2
|
||||
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes,
|
||||
total_classes=10)
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
with train_summary_writer.as_default():
|
||||
train(dataset=train_dataset, iteration=iteration,
|
||||
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||
train_simple(dataset=train_dataset, iteration=iteration,
|
||||
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||
|
||||
Reference in New Issue
Block a user