From 682e11b435cf7d7e1b62eec639651ab3c614a825 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 12:41:10 +0100 Subject: [PATCH] Added step values to image summaries Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 50d755c..cdddfdf 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -192,7 +192,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i os.makedirs(directory) comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) grid = prepare_image(comparison.cpu(), nrow=64) - summary_ops_v2.image(name='reconstruction_' + str(epoch), tensor=k.expand_dims(grid, axis=0), max_images=1) + summary_ops_v2.image(name='reconstruction_' + str(epoch), + tensor=k.expand_dims(grid, axis=0), max_images=1, + step=global_step_decoder) from PIL import Image filename = 'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png' ndarr = grid.cpu().numpy() @@ -220,7 +222,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i directory = 'results' + str(inlier_classes[0]) os.makedirs(directory, exist_ok=True) grid = prepare_image(resultsample) - summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), max_images=1) + summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), + max_images=1, step=global_step_decoder) from PIL import Image filename = 'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png' ndarr = grid.cpu().numpy() @@ -411,7 +414,7 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable: if __name__ == "__main__": tf.enable_eager_execution() inlier_classes = [0] - iteration = 1 + iteration = 2 train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():