@@ -14,7 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""aae.train.py: contains training functionality"""
+"""
+Training functionality for my AAE implementation.
+
+This module provides functions to prepare the training data and subsequently
+train the Adversarial Auto Encoder.
+
+Attributes:
+    GRACE: specifies the number of epochs that the training loss can stagnate or worsen
+        before training is stopped early
+    TOTAL_LOSS_GRACE_CAP: upper limit for the total loss; the grace countdown is only enabled while the total loss is higher
+
+Functions:
+    prepare_training_data(...): prepares the MNIST training data
+    train(...): trains the AAE models
+
+"""
+
 import functools
 import math
 import os
@@ -35,22 +51,26 @@ AdamOptimizer = tf.train.AdamOptimizer
 tfe = tf.contrib.eager
 binary_crossentropy = tf.keras.losses.binary_crossentropy
 
-GRACE = 10
+GRACE: int = 10
+TOTAL_LOSS_GRACE_CAP: int = 6
 
 
 def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
                           fold_prefix: str = 'data/data_fold_',
                           batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
-    Prepares the MNIST data
+    Prepares the MNIST training data.
 
-    :param test_fold_id: id of test fold
-    :param inlier_classes: list of class ids that are considered inliers
-    :param total_classes: total number of classes
-    :param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
-    :param batch_size: size of batch (default: 128)
-    :param folds: number of folds (default: 5)
-    :return: tuple(train dataset, valid dataset)
+    Args:
+        test_fold_id: id of the test fold
+        inlier_classes: list of class ids that are considered inliers
+        total_classes: total number of classes
+        fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
+        batch_size: size of each batch (default: 128)
+        folds: number of folds (default: 5)
+
+    Returns:
+        A tuple (train dataset, valid dataset).
     """
     # prepare data
     mnist_train = []
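As a usage sketch of the new signature (all values are illustrative; this assumes the pickled folds exist under `data/data_fold_*` and that eager execution is enabled, since this module relies on `tf.contrib.eager`):

    # hold out fold 0 and treat digits 0 and 1 as inliers (example values)
    train_ds, valid_ds = prepare_training_data(test_fold_id=0,
                                               inlier_classes=[0, 1],
                                               total_classes=10,
                                               batch_size=128, folds=5)
    for batch_x, batch_y in train_ds:  # datasets are iterable in eager mode
        print(batch_x.shape, batch_y.shape)
        break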
@@ -73,17 +93,20 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
     # keep only train classes
     mnist_train = [x for x in mnist_train if x[0] in inlier_classes]
 
-    def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
+    def _list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
         """
         Converts a list of pairs to a numpy array.
 
-        :param list_of_pairs: list of pairs
-        :return: numpy array
+        Args:
+            list_of_pairs: list of (label, image) pairs
+
+        Returns:
+            A tuple (feature array, label array).
         """
         return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
 
-    mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
-    mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
+    mnist_train_x, mnist_train_y = _list_of_pairs_to_numpy(mnist_train)
+    mnist_valid_x, mnist_valid_y = _list_of_pairs_to_numpy(mnist_valid)
 
     # get dataset
     train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
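The renamed helper's behavior is easy to verify in isolation: labels come first in each pair, but features are returned first. A minimal NumPy-only check (note that `np.int` is just an alias for the built-in `int` and is deprecated in newer NumPy releases):

    import numpy as np

    pairs = [(7, np.zeros((28, 28))), (3, np.ones((28, 28)))]  # (label, image)
    features = np.asarray([x[1] for x in pairs], np.float32)   # shape (2, 28, 28)
    labels = np.asarray([x[0] for x in pairs], int)            # array([7, 3])
    assert features.shape == (2, 28, 28) and labels.tolist() == [7, 3]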
@@ -102,17 +125,35 @@ def train(dataset: tf.data.Dataset, iteration: int,
           batch_size: int = 128, train_epoch: int = 80,
           verbose: bool = True) -> None:
     """
-    Train AAE for given data set.
+    Trains the AAE for the given data set.
 
-    :param dataset: train dataset
-    :param iteration: identifier for the current training run
-    :param weights_prefix: prefix for weights directory
-    :param channels: number of channels in input image (default: 1)
-    :param zsize: size of the intermediary z (default: 32)
-    :param lr: initial learning rate (default: 0.002)
-    :param batch_size: the size of each batch (default: 128)
-    :param train_epoch: number of epochs to train (default: 80)
-    :param verbose: if True prints train progress info to console (default: True)
+    This function provides early stopping and creates checkpoints after every
+    epoch as well as after finishing training (or stopping early). When this
+    function is started with the same ``iteration``, training will try to
+    continue where it ended last time by restoring a saved checkpoint.
+    The loss values are provided as scalar summaries. Reconstruction and sample
+    images are provided as summary images.
+
+    Notes:
+        Training stops early if the loss fails to improve for ``GRACE``
+        epochs. Specifically, all individual losses are accounted for, and any
+        one of them not decreasing triggers a ``strike``. If the total loss,
+        which is the sum of all individual losses, is also not decreasing and
+        exceeds ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace
+        period is decreased. If in any later epoch all losses are decreasing,
+        the grace period is reset to ``GRACE``. The training loop is stopped
+        early once the grace counter reaches ``0`` at the end of an epoch.
+
+    Args:
+        dataset: train dataset
+        iteration: identifier for the current training run
+        weights_prefix: prefix for the weights directory
+        channels: number of channels in the input image (default: 1)
+        zsize: size of the intermediary z (default: 32)
+        lr: initial learning rate (default: 0.002)
+        batch_size: the size of each batch (default: 128)
+        train_epoch: number of epochs to train (default: 80)
+        verbose: if True, prints training progress to the console (default: True)
     """
 
     # non-preserved tensors
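The resume-on-same-``iteration`` behavior described in the docstring could be implemented with a per-iteration checkpoint directory. A sketch under that assumption; the model/optimizer names and the directory layout are placeholders, not necessarily what `train()` does internally (`tf.train.Checkpoint` requires TF >= 1.11):

    import os
    import tensorflow as tf

    # placeholders standing in for the real models and optimizer
    encoder = tf.keras.Sequential([tf.keras.layers.Dense(32)])
    optimizer = tf.train.AdamOptimizer(0.002)

    weights_prefix, iteration = 'weights/', 0  # example values
    checkpoint_dir = os.path.join(weights_prefix, 'iteration_{}'.format(iteration))
    checkpoint = tf.train.Checkpoint(encoder=encoder, optimizer=optimizer)

    # restore the latest checkpoint if this iteration has been trained before
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is not None:
        checkpoint.restore(latest)

    # ... one epoch of training ...
    checkpoint.save(os.path.join(checkpoint_dir, 'ckpt'))  # called after every epoch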
@@ -198,7 +239,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
                 outputs['xd_loss'] + outputs['zd_loss']
             if total_loss < total_lowest_loss:
                 total_lowest_loss = total_loss
-            elif total_loss > 6:
+            elif total_loss > TOTAL_LOSS_GRACE_CAP:
                 total_strike = True
             if outputs['encoder_loss'] < encoder_lowest_loss:
                 encoder_lowest_loss = outputs['encoder_loss']
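To make the grace/strike mechanics concrete, here is a framework-free sketch of the loop described in the docstring's ``Notes`` section (`run_one_epoch` is a hypothetical stand-in for one pass over the data):

    import math

    GRACE = 10
    TOTAL_LOSS_GRACE_CAP = 6

    def run_one_epoch():
        # placeholder: the real loop returns the individual model losses
        return {'encoder_loss': 0.5, 'decoder_loss': 0.7, 'xd_loss': 0.3, 'zd_loss': 0.2}

    grace = GRACE
    lowest = {}  # lowest value seen so far, per loss name

    for epoch in range(80):
        losses = run_one_epoch()
        total_loss = sum(losses.values())

        # a strike: at least one individual loss failed to improve this epoch
        strike = any(loss >= lowest.get(name, math.inf) for name, loss in losses.items())
        # the countdown only runs while the total loss is above the cap
        total_strike = (total_loss >= lowest.get('total', math.inf)
                        and total_loss > TOTAL_LOSS_GRACE_CAP)

        for name, loss in losses.items():
            lowest[name] = min(loss, lowest.get(name, math.inf))
        lowest['total'] = min(total_loss, lowest.get('total', math.inf))

        if strike and total_strike:
            grace -= 1     # losses stagnated while the total loss is still high
        elif not strike:
            grace = GRACE  # every loss improved again: reset the grace period
        if grace == 0:
            break          # stop training early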