Added training iteration variable to train function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -37,6 +37,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
|||||||
|
|
||||||
|
|
||||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||||
|
iteration: int,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
folds: int = 5, verbose: bool = True):
|
folds: int = 5, verbose: bool = True):
|
||||||
@ -46,6 +47,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
:param folding_id: id of fold used for test data
|
:param folding_id: id of fold used for test data
|
||||||
:param inlier_classes: list of class ids that are considered inliers
|
:param inlier_classes: list of class ids that are considered inliers
|
||||||
:param total_classes: total number of classes
|
:param total_classes: total number of classes
|
||||||
|
:param iteration: identifier for the current training run
|
||||||
:param channels: number of channels in input image
|
:param channels: number of channels in input image
|
||||||
:param zsize: size of the intermediary z
|
:param zsize: size of the intermediary z
|
||||||
:param lr: learning rate
|
:param lr: learning rate
|
||||||
@ -291,10 +293,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
print("Training stopped early!... save model weights")
|
print("Training stopped early!... save model weights")
|
||||||
|
|
||||||
# save trained models
|
# save trained models
|
||||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
|
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]))
|
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]))
|
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]))
|
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
|
|
||||||
|
|
||||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||||
@ -475,4 +477,4 @@ if __name__ == "__main__":
|
|||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||||
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10)
|
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration)
|
||||||
|
|||||||
Reference in New Issue
Block a user