Improved docstrings

Signed-off-by: Jim Martens <github@2martens.de>
2019-02-08 17:12:41 +01:00
parent 028513b404
commit 1d4f56e40d


@@ -14,7 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""aae.train.py: contains training functionality"""
"""
Training functionality for my AAE implementation.
This module provides functions to prepare the training data and subsequently
train the Adversarial Autoencoder.
Attributes:
GRACE: number of epochs the training loss may stagnate or worsen
before the training is stopped early
TOTAL_LOSS_GRACE_CAP: total-loss threshold; the grace countdown only runs while the total loss exceeds this value
Functions:
prepare_training_data(...): prepares the mnist training data
train(...): trains the AAE models
"""
import functools
import math
import os
@@ -35,22 +51,26 @@ AdamOptimizer = tf.train.AdamOptimizer
tfe = tf.contrib.eager
binary_crossentropy = tf.keras.losses.binary_crossentropy
GRACE = 10
GRACE: int = 10
TOTAL_LOSS_GRACE_CAP: int = 6
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
fold_prefix: str = 'data/data_fold_',
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
"""
Prepares the MNIST data
Prepares the MNIST training data.
:param test_fold_id: id of test fold
:param inlier_classes: list of class ids that are considered inliers
:param total_classes: total number of classes
:param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
:param batch_size: size of batch (default: 128)
:param folds: number of folds (default: 5)
:return: tuple(train dataset, valid dataset)
Args:
test_fold_id: id of test fold
inlier_classes: list of class ids that are considered inliers
total_classes: total number of classes
fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
batch_size: size of batch (default: 128)
folds: number of folds (default: 5)
Returns:
A tuple (train dataset, valid dataset)
"""
# prepare data
mnist_train = []
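For illustration, a hypothetical call of the function documented above; it assumes the default fold pickle files exist under ``data/`` and that the module has been imported:

# Treat digit 0 as the inlier class among the 10 MNIST classes.
train_dataset, valid_dataset = prepare_training_data(
    test_fold_id=0,
    inlier_classes=[0],
    total_classes=10,
    batch_size=128,
    folds=5)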
@@ -73,17 +93,20 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
# keep only train classes
mnist_train = [x for x in mnist_train if x[0] in inlier_classes]
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
def _list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
"""
Converts a list of (label, image) pairs to a pair of numpy arrays.
:param list_of_pairs: list of pairs
:return: numpy array
Args:
list_of_pairs: list of pairs
Returns:
tuple (feature array, label array)
"""
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
mnist_train_x, mnist_train_y = _list_of_pairs_to_numpy(mnist_train)
mnist_valid_x, mnist_valid_y = _list_of_pairs_to_numpy(mnist_valid)
# get dataset
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
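For reference, a minimal sketch of how such a dataset is typically shuffled and batched with the TF 1.x ``tf.data`` API; the array names mirror the variables above, but the placeholder arrays and the shuffle/batch calls are assumptions, not taken from this commit:

import numpy as np
import tensorflow as tf

# Placeholder arrays standing in for the MNIST folds loaded above.
mnist_train_x = np.zeros((1024, 28, 28), np.float32)
mnist_train_y = np.zeros((1024,), np.int32)

# Build the dataset from tensor slices, then shuffle and batch it.
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(128)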
@@ -102,17 +125,35 @@ def train(dataset: tf.data.Dataset, iteration: int,
batch_size: int = 128, train_epoch: int = 80,
verbose: bool = True) -> None:
"""
Train AAE for given data set.
Trains the AAE on the given dataset.
:param dataset: train dataset
:param iteration: identifier for the current training run
:param weights_prefix: prefix for weights directory
:param channels: number of channels in input image (default: 1)
:param zsize: size of the intermediary z (default: 32)
:param lr: initial learning rate (default: 0.002)
:param batch_size: the size of each batch (default: 128)
:param train_epoch: number of epochs to train (default: 80)
:param verbose: if True prints train progress info to console (default: True)
This function provides early stopping and creates checkpoints after every
epoch as well as after finishing training (or stopping early). When this
function is started again with the same ``iteration``, the training will
try to continue where it ended last time by restoring a saved checkpoint.
The loss values are provided as scalar summaries. Reconstruction and sample
images are provided as summary images.
Notes:
The training stops early if the loss does not decrease for ``GRACE``
epochs. Specifically, every individual loss is tracked, and any one of
them not decreasing counts as a ``strike``. If the total loss, the sum
of all individual losses, is also not decreasing and exceeds
``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
decremented. If all losses decrease in a later epoch, the grace period
is reset to ``GRACE``. Finally, the training loop is stopped early if
the grace counter reaches ``0`` at the end of an epoch.
Args:
dataset: train dataset
iteration: identifier for the current training run
weights_prefix: prefix for weights directory
channels: number of channels in input image (default: 1)
zsize: size of the intermediary z (default: 32)
lr: initial learning rate (default: 0.002)
batch_size: the size of each batch (default: 128)
train_epoch: number of epochs to train (default: 80)
verbose: if True prints train progress info to console (default: True)
"""
# non-preserved tensors
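A minimal sketch of the checkpoint-and-resume behaviour described in the docstring, using TF 1.x eager execution; ``encoder`` and ``optimizer`` are hypothetical stand-ins for the actual AAE models, not the commit's own code:

import os
import tensorflow as tf

tf.enable_eager_execution()

# Hypothetical stand-ins for the real AAE models and optimizer.
encoder = tf.keras.layers.Dense(32)
optimizer = tf.train.AdamOptimizer(learning_rate=0.002)

weights_prefix = 'weights/'
iteration = 0
checkpoint_dir = weights_prefix + str(iteration)
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint = tf.train.Checkpoint(encoder=encoder, optimizer=optimizer)
# Restore the latest checkpoint for this iteration if one exists.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# ... inside the training loop, after each epoch ...
checkpoint.save(file_prefix=checkpoint_dir + '/ckpt')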
@@ -198,7 +239,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
outputs['xd_loss'] + outputs['zd_loss']
if total_loss < total_lowest_loss:
total_lowest_loss = total_loss
elif total_loss > 6:
elif total_loss > TOTAL_LOSS_GRACE_CAP:
total_strike = True
if outputs['encoder_loss'] < encoder_lowest_loss:
encoder_lowest_loss = outputs['encoder_loss']
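Putting the pieces together, a self-contained sketch of the grace countdown described in the Notes section above; ``epoch_losses`` is a hypothetical stand-in for the real per-epoch training loop, and the loss values are made up so the early stop actually triggers:

GRACE = 10
TOTAL_LOSS_GRACE_CAP = 6

def epoch_losses():
    """Hypothetical stand-in yielding one loss dict per epoch.
    All losses stagnate above the cap, so training stops early."""
    for _ in range(80):
        yield {'encoder_loss': 2.0, 'decoder_loss': 2.0,
               'xd_loss': 2.0, 'zd_loss': 2.0}

grace = GRACE
lowest = {}                          # lowest value seen so far per loss
total_lowest_loss = float('inf')

for losses in epoch_losses():
    strikes = 0
    for name, value in losses.items():
        if value < lowest.get(name, float('inf')):
            lowest[name] = value
        else:
            strikes += 1             # this loss did not decrease
    total_loss = sum(losses.values())
    if total_loss < total_lowest_loss:
        total_lowest_loss = total_loss
    elif total_loss > TOTAL_LOSS_GRACE_CAP and strikes:
        grace -= 1                   # countdown only above the cap
    if not strikes:
        grace = GRACE                # every loss improved: reset grace
    if grace == 0:
        break                        # stop training early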