From b173746784fa8bc4818eec339ecca542fa32c15b Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Wed, 10 Apr 2019 15:19:44 +0200 Subject: [PATCH] Ensured summaries are written Signed-off-by: Jim Martens --- src/twomartens/masterthesis/main.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 4e93695..8307b11 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -68,6 +68,7 @@ def _build_train(parser: argparse.ArgumentParser) -> None: def _build_auto_encoder(parser: argparse.ArgumentParser) -> None: parser.add_argument("--coco_path", type=str, help="the path to the COCO data set") parser.add_argument("--weights_path", type=str, help="path to the weights directory") + parser.add_argument("--summary_path", type=str, help="path to the summaries directory") parser.add_argument("category", type=int, help="the COCO category to use") parser.add_argument("num_epochs", type=int, help="the number of epochs to train", default=80) parser.add_argument("iteration", type=int, help="the training iteration") @@ -95,13 +96,21 @@ def _use(args: argparse.Namespace) -> None: def _auto_encoder_train(args: argparse.Namespace) -> None: from twomartens.masterthesis import data from twomartens.masterthesis.aae import train - + import tensorflow as tf + from tensorflow.python.ops import summary_ops_v2 + + tf.enable_eager_execution() coco_path = args.coco_path category = args.category batch_size = 32 coco_data = data.load_coco(coco_path, category, num_epochs=args.num_epochs, batch_size=batch_size) - train.train_simple(coco_data, iteration=args.iteration, weights_prefix=args.weights_path, - channels=3, train_epoch=args.num_epochs, batch_size=batch_size) + train_summary_writer = summary_ops_v2.create_file_writer( + f"{args.summary_path}/train/category-{category}/{args.iteration}" + ) + with train_summary_writer.as_default(): + train.train_simple(coco_data, iteration=args.iteration, + weights_prefix=f"{args.weights_path}/category-{category}", + channels=3, train_epoch=args.num_epochs, batch_size=batch_size) def _bayesian_ssd_train(args: argparse.Namespace) -> None: