From 658973d90d7b37dd218cb3be75d803b154c18f55 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 15:34:49 +0100 Subject: [PATCH] Added weights prefix to parameters Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 078e498..e67e7b3 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -97,6 +97,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, + weights_prefix: str, channels: int = 1, zsize: int = 32, lr: float = 0.002, batch_size: int = 128, train_epoch: int = 80, verbose: bool = True) -> None: @@ -106,6 +107,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, :param dataset: train dataset :param iteration: identifier for the current training run :param result_prefix: prefix for result images + :param weights_prefix: prefix for weights directory :param channels: number of channels in input image (default: 1) :param zsize: size of the intermediary z (default: 32) :param lr: initial learning rate (default: 0.002) @@ -147,7 +149,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, total_lowest_loss = math.inf grace_period = GRACE - checkpoint_dir = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/' + checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/') os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) @@ -518,4 +520,6 @@ if __name__ == "__main__": train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): - train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/') + train(dataset=train_dataset, iteration=iteration, + result_prefix='results' + str(inlier_classes[0]) + '/', + weights_prefix='weights/' + str(inlier_classes[0]) + '/')