Added weights prefix to parameters

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 15:34:49 +01:00
parent 88ffd4f879
commit 658973d90d

View File

@ -97,6 +97,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
weights_prefix: str,
channels: int = 1, zsize: int = 32, lr: float = 0.002, channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80, batch_size: int = 128, train_epoch: int = 80,
verbose: bool = True) -> None: verbose: bool = True) -> None:
@ -106,6 +107,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
:param dataset: train dataset :param dataset: train dataset
:param iteration: identifier for the current training run :param iteration: identifier for the current training run
:param result_prefix: prefix for result images :param result_prefix: prefix for result images
:param weights_prefix: prefix for weights directory
:param channels: number of channels in input image (default: 1) :param channels: number of channels in input image (default: 1)
:param zsize: size of the intermediary z (default: 32) :param zsize: size of the intermediary z (default: 32)
:param lr: initial learning rate (default: 0.002) :param lr: initial learning rate (default: 0.002)
@ -147,7 +149,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
total_lowest_loss = math.inf total_lowest_loss = math.inf
grace_period = GRACE grace_period = GRACE
checkpoint_dir = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/' checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
@ -518,4 +520,6 @@ if __name__ == "__main__":
train_summary_writer = summary_ops_v2.create_file_writer( train_summary_writer = summary_ops_v2.create_file_writer(
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/') train(dataset=train_dataset, iteration=iteration,
result_prefix='results' + str(inlier_classes[0]) + '/',
weights_prefix='weights/' + str(inlier_classes[0]) + '/')