Added training iteration variable to train function

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 13:36:09 +01:00
parent 3bde913323
commit a0094266b3

View File

@ -37,6 +37,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
iteration: int,
channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80,
folds: int = 5, verbose: bool = True):
@ -46,6 +47,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
:param folding_id: id of fold used for test data
:param inlier_classes: list of class ids that are considered inliers
:param total_classes: total number of classes
:param iteration: identifier for the current training run
:param channels: number of channels in input image
:param zsize: size of the intermediary z
:param lr: learning rate
@ -291,10 +293,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
print("Training stopped early!... save model weights")
# save trained models
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]))
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]))
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]))
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration))
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration))
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
@ -475,4 +477,4 @@ if __name__ == "__main__":
train_summary_writer = summary_ops_v2.create_file_writer(
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10)
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration)