@ -14,7 +14,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""aae.train.py: contains training functionality"""
|
||||
"""
|
||||
Training functionality for my AAE implementation.
|
||||
|
||||
This module provides functions to prepare the training data and subsequently
|
||||
train the Adversarial Auto Encoder.
|
||||
|
||||
Attributes:
|
||||
GRACE: specifies the number of epochs that the training loss can stagnate or worsen
|
||||
before the training is stopped early
|
||||
TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher
|
||||
|
||||
Functions:
|
||||
prepare_training_data(...): prepares the mnist training data
|
||||
train(...): trains the AAE models
|
||||
|
||||
"""
|
||||
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
@ -35,22 +51,26 @@ AdamOptimizer = tf.train.AdamOptimizer
|
||||
tfe = tf.contrib.eager
|
||||
binary_crossentropy = tf.keras.losses.binary_crossentropy
|
||||
|
||||
GRACE = 10
|
||||
GRACE: int = 10
|
||||
TOTAL_LOSS_GRACE_CAP: int = 6
|
||||
|
||||
|
||||
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||
fold_prefix: str = 'data/data_fold_',
|
||||
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
||||
"""
|
||||
Prepares the MNIST data
|
||||
Prepares the MNIST training data.
|
||||
|
||||
:param test_fold_id: id of test fold
|
||||
:param inlier_classes: list of class ids that are considered inliers
|
||||
:param total_classes: total number of classes
|
||||
:param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
|
||||
:param batch_size: size of batch (default: 128)
|
||||
:param folds: number of folds (default: 5)
|
||||
:return: tuple(train dataset, valid dataset)
|
||||
Args:
|
||||
test_fold_id: id of test fold
|
||||
inlier_classes: list of class ids that are considered inliers
|
||||
total_classes: total number of classes
|
||||
fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
|
||||
batch_size: size of batch (default: 128)
|
||||
folds: number of folds (default: 5)
|
||||
|
||||
Returns:
|
||||
A tuple (train dataset, valid dataset)
|
||||
"""
|
||||
# prepare data
|
||||
mnist_train = []
|
||||
@ -73,17 +93,20 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
||||
# keep only train classes
|
||||
mnist_train = [x for x in mnist_train if x[0] in inlier_classes]
|
||||
|
||||
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
def _list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts a list of pairs to a numpy array.
|
||||
|
||||
:param list_of_pairs: list of pairs
|
||||
:return: numpy array
|
||||
Args:
|
||||
list_of_pairs: list of pairs
|
||||
|
||||
Returns:
|
||||
tuple (feature array, label array)
|
||||
"""
|
||||
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
|
||||
|
||||
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
||||
mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
|
||||
mnist_train_x, mnist_train_y = _list_of_pairs_to_numpy(mnist_train)
|
||||
mnist_valid_x, mnist_valid_y = _list_of_pairs_to_numpy(mnist_valid)
|
||||
|
||||
# get dataset
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
@ -102,17 +125,35 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
verbose: bool = True) -> None:
|
||||
"""
|
||||
Train AAE for given data set.
|
||||
Trains AAE for given data set.
|
||||
|
||||
:param dataset: train dataset
|
||||
:param iteration: identifier for the current training run
|
||||
:param weights_prefix: prefix for weights directory
|
||||
:param channels: number of channels in input image (default: 1)
|
||||
:param zsize: size of the intermediary z (default: 32)
|
||||
:param lr: initial learning rate (default: 0.002)
|
||||
:param batch_size: the size of each batch (default: 128)
|
||||
:param train_epoch: number of epochs to train (default: 80)
|
||||
:param verbose: if True prints train progress info to console (default: True)
|
||||
This function provides early stopping and creates checkpoints after every
|
||||
epoch as well as after finishing training (or stopping early). When starting
|
||||
this function with the same ``iteration`` then the training will try to
|
||||
continue where it ended last time by restoring a saved checkpoint.
|
||||
The loss values are provided as scalar summaries. Reconstruction and sample
|
||||
images are provided as summary images.
|
||||
|
||||
Notes:
|
||||
The training stops early if for ``GRACE`` number of epochs the loss is not
|
||||
decreasing. Specifically all individual losses are accounted for and any one
|
||||
of those not decreasing triggers a ``strike``. If the total loss, which is
|
||||
a sum of all individual losses, is also not decreasing and has a total
|
||||
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
|
||||
decreased. If in any epoch afterwards all losses are decreasing the grace
|
||||
period is reset to ``GRACE``. Lastly the training loop will be stopped early
|
||||
if the grace counter reaches ``0`` at the end of an epoch.
|
||||
|
||||
Args:
|
||||
dataset: train dataset
|
||||
iteration: identifier for the current training run
|
||||
weights_prefix: prefix for weights directory
|
||||
channels: number of channels in input image (default: 1)
|
||||
zsize: size of the intermediary z (default: 32)
|
||||
lr: initial learning rate (default: 0.002)
|
||||
batch_size: the size of each batch (default: 128)
|
||||
train_epoch: number of epochs to train (default: 80)
|
||||
verbose: if True prints train progress info to console (default: True)
|
||||
"""
|
||||
|
||||
# non-preserved tensors
|
||||
@ -198,7 +239,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
outputs['xd_loss'] + outputs['zd_loss']
|
||||
if total_loss < total_lowest_loss:
|
||||
total_lowest_loss = total_loss
|
||||
elif total_loss > 6:
|
||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||
total_strike = True
|
||||
if outputs['encoder_loss'] < encoder_lowest_loss:
|
||||
encoder_lowest_loss = outputs['encoder_loss']
|
||||
|
||||
Reference in New Issue
Block a user