diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 5d5d0bb..6418eb2 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -14,7 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""aae.train.py: contains training functionality"""
+"""
+Training functionality for the AAE implementation.
+
+This module provides functions to prepare the training data and subsequently
+train the Adversarial Auto Encoder.
+
+Attributes:
+    GRACE: number of epochs the training loss may stagnate or worsen before
+        training is stopped early
+    TOTAL_LOSS_GRACE_CAP: upper limit for the total loss; the grace countdown only runs while the total loss exceeds it
+
+Functions:
+    prepare_training_data(...): prepares the MNIST training data
+    train(...): trains the AAE models
+
+"""
+
 import functools
 import math
 import os
@@ -35,22 +51,26 @@ AdamOptimizer = tf.train.AdamOptimizer
 tfe = tf.contrib.eager
 binary_crossentropy = tf.keras.losses.binary_crossentropy

-GRACE = 10
+GRACE: int = 10
+TOTAL_LOSS_GRACE_CAP: int = 6


 def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
                           fold_prefix: str = 'data/data_fold_',
                           batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
-    Prepares the MNIST data
+    Prepares the MNIST training data.

-    :param test_fold_id: id of test fold
-    :param inlier_classes: list of class ids that are considered inliers
-    :param total_classes: total number of classes
-    :param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
-    :param batch_size: size of batch (default: 128)
-    :param folds: number of folds (default: 5)
-    :return: tuple(train dataset, valid dataset)
+    Args:
+        test_fold_id: ID of the test fold
+        inlier_classes: list of class IDs that are considered inliers
+        total_classes: total number of classes
+        fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
+        batch_size: size of each batch (default: 128)
+        folds: number of folds (default: 5)
+
+    Returns:
+        A tuple (train dataset, validation dataset)
     """
     # prepare data
     mnist_train = []
@@ -73,17 +93,20 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
     # keep only train classes
     mnist_train = [x for x in mnist_train if x[0] in inlier_classes]

-    def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
+    def _list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
         """
         Converts a list of pairs to a numpy array.

-        :param list_of_pairs: list of pairs
-        :return: numpy array
+        Args:
+            list_of_pairs: list of pairs
+
+        Returns:
+            A tuple (feature array, label array)
         """
         return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)

-    mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
-    mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
+    mnist_train_x, mnist_train_y = _list_of_pairs_to_numpy(mnist_train)
+    mnist_valid_x, mnist_valid_y = _list_of_pairs_to_numpy(mnist_valid)

     # get dataset
     train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
@@ -102,17 +125,35 @@ def train(dataset: tf.data.Dataset, iteration: int,
           batch_size: int = 128, train_epoch: int = 80,
           verbose: bool = True) -> None:
     """
-    Train AAE for given data set.
+    Trains the AAE for the given dataset.

-    :param dataset: train dataset
-    :param iteration: identifier for the current training run
-    :param weights_prefix: prefix for weights directory
-    :param channels: number of channels in input image (default: 1)
-    :param zsize: size of the intermediary z (default: 32)
-    :param lr: initial learning rate (default: 0.002)
-    :param batch_size: the size of each batch (default: 128)
-    :param train_epoch: number of epochs to train (default: 80)
-    :param verbose: if True prints train progress info to console (default: True)
+    This function provides early stopping and creates checkpoints after every
+    epoch as well as after finishing training (or stopping early). When this
+    function is called again with the same ``iteration``, training tries to
+    continue where it left off by restoring a saved checkpoint.
+    The loss values are provided as scalar summaries; reconstruction and sample
+    images are provided as summary images.
+
+    Notes:
+        Training stops early if the loss does not decrease for ``GRACE``
+        epochs. Specifically, all individual losses are tracked, and any one
+        of them not decreasing triggers a ``strike``. If the total loss, the
+        sum of all individual losses, is also not decreasing and exceeds
+        ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period
+        is decreased. If all losses decrease in a later epoch, the grace
+        period is reset to ``GRACE``. Training stops early once the grace
+        counter reaches ``0`` at the end of an epoch.
+
+    Args:
+        dataset: train dataset
+        iteration: identifier for the current training run
+        weights_prefix: prefix for the weights directory
+        channels: number of channels in the input image (default: 1)
+        zsize: size of the intermediary z (default: 32)
+        lr: initial learning rate (default: 0.002)
+        batch_size: the size of each batch (default: 128)
+        train_epoch: number of epochs to train (default: 80)
+        verbose: if True, prints training progress to the console (default: True)
     """

     # non-preserved tensors
@@ -198,7 +239,7 @@
                 outputs['xd_loss'] + outputs['zd_loss']
             if total_loss < total_lowest_loss:
                 total_lowest_loss = total_loss
-            elif total_loss > 6:
+            elif total_loss > TOTAL_LOSS_GRACE_CAP:
                 total_strike = True
             if outputs['encoder_loss'] < encoder_lowest_loss:
                 encoder_lowest_loss = outputs['encoder_loss']
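For context, a minimal usage sketch of the two functions documented in this diff. The import path follows the file path above, and the keyword names match the docstrings; the eager-execution setup and the concrete argument values are assumptions, not part of the diff.

# Hypothetical usage sketch; call-site details are assumed, not taken
# from this diff.
import tensorflow as tf

from twomartens.masterthesis.aae import train as aae_train

tf.enable_eager_execution()  # the module relies on tf.contrib.eager (TF 1.x)

# treat the digit 3 as the only inlier class, using the default fold files
train_ds, valid_ds = aae_train.prepare_training_data(test_fold_id=0,
                                                     inlier_classes=[3],
                                                     total_classes=10)

# iteration identifies the run; reusing it resumes from the saved checkpoint
aae_train.train(train_ds, iteration=0, weights_prefix='weights/')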
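The grace-period rule described in the ``Notes`` section can be summarized as follows. This is a simplified sketch of the documented behaviour, not the actual code in train.py; the function name and the dict-based bookkeeping are illustrative.

# Simplified sketch of the early-stopping rule from the Notes section;
# names and structure are illustrative only.
GRACE = 10
TOTAL_LOSS_GRACE_CAP = 6

def update_grace(epoch_losses: dict, lowest: dict, grace_left: int) -> int:
    """Returns the remaining grace after one epoch; 0 means stop early."""
    strike = False
    for name, loss in epoch_losses.items():
        if loss < lowest[name]:
            lowest[name] = loss    # this loss improved: remember new best
        else:
            strike = True          # any stagnating individual loss is a strike

    total = sum(epoch_losses.values())
    total_strike = False
    if total < lowest['total']:
        lowest['total'] = total
    elif total > TOTAL_LOSS_GRACE_CAP:
        total_strike = True        # total loss stagnates above the cap

    if strike and total_strike:
        grace_left -= 1            # count down towards early stopping
    elif not strike:
        grace_left = GRACE         # all losses decreasing: reset the grace period
    return grace_left

Initializing every entry of ``lowest`` to ``float('inf')`` makes the first epoch always count as an improvement, so the countdown can only start once real stagnation occurs.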