@ -39,7 +39,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
|||||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
folds: int = 5):
|
folds: int = 5, verbose: bool = True):
|
||||||
"""
|
"""
|
||||||
Train AAE for mnist data set.
|
Train AAE for mnist data set.
|
||||||
|
|
||||||
@ -52,6 +52,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
:param batch_size: size of each batch
|
:param batch_size: size of each batch
|
||||||
:param train_epoch: number of epochs to train
|
:param train_epoch: number of epochs to train
|
||||||
:param folds: number of folds available
|
:param folds: number of folds available
|
||||||
|
:param verbose: if True prints train progress info to console
|
||||||
"""
|
"""
|
||||||
# prepare data
|
# prepare data
|
||||||
mnist_train = []
|
mnist_train = []
|
||||||
@ -134,7 +135,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
# update learning rate
|
# update learning rate
|
||||||
if (epoch + 1) % 30 == 0:
|
if (epoch + 1) % 30 == 0:
|
||||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||||
print("learning rate change!")
|
if verbose:
|
||||||
|
print("learning rate change!")
|
||||||
|
|
||||||
nr_batches = len(mnist_train_x) // batch_size
|
nr_batches = len(mnist_train_x) // batch_size
|
||||||
log_frequency = 10
|
log_frequency = 10
|
||||||
@ -209,15 +211,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
|
|
||||||
print((
|
if verbose:
|
||||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
print((
|
||||||
f"train time: {per_epoch_time:.2f}, "
|
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||||
f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
|
f"train time: {per_epoch_time:.2f}, "
|
||||||
f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
|
f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
|
||||||
f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
|
f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
|
||||||
f"Encoder loss: {encoder_loss_avg.result():.3f}"
|
f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
|
||||||
))
|
f"Encoder loss: {encoder_loss_avg.result():.3f}"
|
||||||
|
))
|
||||||
|
|
||||||
# save sample image
|
# save sample image
|
||||||
resultsample = decoder(sample).cpu()
|
resultsample = decoder(sample).cpu()
|
||||||
@ -225,8 +228,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
os.makedirs(directory, exist_ok=True)
|
os.makedirs(directory, exist_ok=True)
|
||||||
save_image(resultsample,
|
save_image(resultsample,
|
||||||
'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')
|
'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')
|
||||||
|
if verbose:
|
||||||
print("Training finish!... save training results")
|
print("Training finish!... save training results")
|
||||||
|
|
||||||
|
# save trained models
|
||||||
encoder.save_weights("./weights/encoder")
|
encoder.save_weights("./weights/encoder")
|
||||||
decoder.save_weights("./weights/decoder")
|
decoder.save_weights("./weights/decoder")
|
||||||
z_discriminator.save_weights("./weights/z_discriminator")
|
z_discriminator.save_weights("./weights/z_discriminator")
|
||||||
|
|||||||
Reference in New Issue
Block a user