Added weights prefix to parameters
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -97,6 +97,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
|||||||
|
|
||||||
|
|
||||||
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||||
|
weights_prefix: str,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
verbose: bool = True) -> None:
|
verbose: bool = True) -> None:
|
||||||
@ -106,6 +107,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
:param dataset: train dataset
|
:param dataset: train dataset
|
||||||
:param iteration: identifier for the current training run
|
:param iteration: identifier for the current training run
|
||||||
:param result_prefix: prefix for result images
|
:param result_prefix: prefix for result images
|
||||||
|
:param weights_prefix: prefix for weights directory
|
||||||
:param channels: number of channels in input image (default: 1)
|
:param channels: number of channels in input image (default: 1)
|
||||||
:param zsize: size of the intermediary z (default: 32)
|
:param zsize: size of the intermediary z (default: 32)
|
||||||
:param lr: initial learning rate (default: 0.002)
|
:param lr: initial learning rate (default: 0.002)
|
||||||
@ -147,7 +149,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
total_lowest_loss = math.inf
|
total_lowest_loss = math.inf
|
||||||
grace_period = GRACE
|
grace_period = GRACE
|
||||||
|
|
||||||
checkpoint_dir = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/'
|
checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
|
||||||
@ -518,4 +520,6 @@ if __name__ == "__main__":
|
|||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||||
train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/')
|
train(dataset=train_dataset, iteration=iteration,
|
||||||
|
result_prefix='results' + str(inlier_classes[0]) + '/',
|
||||||
|
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||||
|
|||||||
Reference in New Issue
Block a user