Added weights prefix to parameters
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -97,6 +97,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
||||
|
||||
|
||||
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||
weights_prefix: str,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
verbose: bool = True) -> None:
|
||||
@ -106,6 +107,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||
:param dataset: train dataset
|
||||
:param iteration: identifier for the current training run
|
||||
:param result_prefix: prefix for result images
|
||||
:param weights_prefix: prefix for weights directory
|
||||
:param channels: number of channels in input image (default: 1)
|
||||
:param zsize: size of the intermediary z (default: 32)
|
||||
:param lr: initial learning rate (default: 0.002)
|
||||
@ -147,7 +149,7 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||
total_lowest_loss = math.inf
|
||||
grace_period = GRACE
|
||||
|
||||
checkpoint_dir = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/'
|
||||
checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
|
||||
@ -518,4 +520,6 @@ if __name__ == "__main__":
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||
train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/')
|
||||
train(dataset=train_dataset, iteration=iteration,
|
||||
result_prefix='results' + str(inlier_classes[0]) + '/',
|
||||
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||
|
||||
Reference in New Issue
Block a user