@ -39,7 +39,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
folds: int = 5):
|
||||
folds: int = 5, verbose: bool = True):
|
||||
"""
|
||||
Train AAE for mnist data set.
|
||||
|
||||
@ -52,6 +52,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
:param batch_size: size of each batch
|
||||
:param train_epoch: number of epochs to train
|
||||
:param folds: number of folds available
|
||||
:param verbose: if True prints train progress info to console
|
||||
"""
|
||||
# prepare data
|
||||
mnist_train = []
|
||||
@ -134,6 +135,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
# update learning rate
|
||||
if (epoch + 1) % 30 == 0:
|
||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||
if verbose:
|
||||
print("learning rate change!")
|
||||
|
||||
nr_batches = len(mnist_train_x) // batch_size
|
||||
@ -210,6 +212,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
epoch_end_time = time.time()
|
||||
per_epoch_time = epoch_end_time - epoch_start_time
|
||||
|
||||
if verbose:
|
||||
print((
|
||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||
f"train time: {per_epoch_time:.2f}, "
|
||||
@ -225,8 +228,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
save_image(resultsample,
|
||||
'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')
|
||||
|
||||
if verbose:
|
||||
print("Training finish!... save training results")
|
||||
|
||||
# save trained models
|
||||
encoder.save_weights("./weights/encoder")
|
||||
decoder.save_weights("./weights/decoder")
|
||||
z_discriminator.save_weights("./weights/z_discriminator")
|
||||
|
||||
Reference in New Issue
Block a user