From 416cb4c7b594e5859f352e9aada70c6a186ce3c0 Mon Sep 17 00:00:00 2001
From: Jim Martens
Date: Fri, 8 Feb 2019 22:42:49 +0100
Subject: [PATCH] Reformatted code to ensure all parameters are on their own
 line

Signed-off-by: Jim Martens
---
 src/twomartens/masterthesis/aae/train.py | 90 +++++++++++++++---------
 1 file changed, 58 insertions(+), 32 deletions(-)

diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 808ab15..6c2f849 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -62,9 +62,12 @@ TOTAL_LOSS_GRACE_CAP: int = 6
 LOG_FREQUENCY: int = 10
 
 
-def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
+def prepare_training_data(test_fold_id: int,
+                          inlier_classes: Sequence[int],
+                          total_classes: int,
                           fold_prefix: str = 'data/data_fold_',
-                          batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
+                          batch_size: int = 128,
+                          folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
     Prepares the MNIST training data.
 
@@ -126,11 +129,16 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
     return train_dataset, valid_dataset
 
 
-def train(dataset: tf.data.Dataset, iteration: int,
+def train(dataset: tf.data.Dataset,
+          iteration: int,
           weights_prefix: str,
-          channels: int = 1, zsize: int = 32, lr: float = 0.002,
-          batch_size: int = 128, train_epoch: int = 80,
-          verbose: bool = True, early_stopping: bool = False) -> None:
+          channels: int = 1,
+          zsize: int = 32,
+          lr: float = 0.002,
+          batch_size: int = 128,
+          train_epoch: int = 80,
+          verbose: bool = True,
+          early_stopping: bool = False) -> None:
     """
     Trains AAE for given data set.
 
@@ -141,16 +149,6 @@ def train(dataset: tf.data.Dataset, iteration: int,
     The loss values are provided as scalar summaries. Reconstruction and sample
     images are provided as summary images.
 
-    Notes:
-        The training stops early if for ``GRACE`` number of epochs the loss is not
-        decreasing. Specifically all individual losses are accounted for and any one
-        of those not decreasing triggers a ``strike``. If the total loss, which is
-        a sum of all individual losses, is also not decreasing and has a total
-        value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
-        decreased. If in any epoch afterwards all losses are decreasing the grace
-        period is reset to ``GRACE``. Lastly the training loop will be stopped early
-        if the grace counter reaches ``0`` at the end of an epoch.
-
     Args:
         dataset: train dataset
         iteration: identifier for the current training run
@@ -162,6 +160,16 @@ def train(dataset: tf.data.Dataset, iteration: int,
         train_epoch: number of epochs to train (default: 80)
         verbose: if True prints train progress info to console (default: True)
        early_stopping: if True the early stopping mechanic is enabled (default: False)
+
+    Notes:
+        The training stops early if for ``GRACE`` number of epochs the loss is not
+        decreasing. Specifically all individual losses are accounted for and any one
+        of those not decreasing triggers a ``strike``. If the total loss, which is
+        a sum of all individual losses, is also not decreasing and has a total
+        value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
+        decreased. If in any epoch afterwards all losses are decreasing the grace
+        period is reset to ``GRACE``. Lastly the training loop will be stopped early
+        if the grace counter reaches ``0`` at the end of an epoch.
""" # non-preserved tensors @@ -308,17 +316,24 @@ def train(dataset: tf.data.Dataset, iteration: int, checkpoint.save(checkpoint_prefix) -def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor, +def _train_one_epoch(epoch: int, + dataset: tf.data.Dataset, + targets_real: tf.Tensor, verbose: bool, - targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable], + targets_fake: tf.Tensor, + z_generator: Callable[[], tf.Variable], learning_rate_var: tf.Variable, - decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator, - z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer, + decoder: model.Decoder, + encoder: model.Encoder, + x_discriminator: model.XDiscriminator, + z_discriminator: model.ZDiscriminator, + decoder_optimizer: tf.train.Optimizer, x_discriminator_optimizer: tf.train.Optimizer, z_discriminator_optimizer: tf.train.Optimizer, enc_dec_optimizer: tf.train.Optimizer, global_step: tf.Variable, - global_step_xd: tf.Variable, global_step_zd: tf.Variable, + global_step_xd: tf.Variable, + global_step_zd: tf.Variable, global_step_decoder: tf.Variable, global_step_enc_dec: tf.Variable, epoch_var: tf.Variable) -> Dict[str, float]: @@ -415,10 +430,13 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens return outputs -def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder, +def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, + decoder: model.Decoder, optimizer: tf.train.Optimizer, - inputs: tf.Tensor, targets_real: tf.Tensor, - targets_fake: tf.Tensor, global_step: tf.Variable, + inputs: tf.Tensor, + targets_real: tf.Tensor, + targets_fake: tf.Tensor, + global_step: tf.Variable, global_step_xd: tf.Variable, z_generator: Callable[[], tf.Variable]) -> tf.Tensor: """ @@ -465,9 +483,11 @@ def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: m return _xd_train_loss -def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator, +def _train_decoder_step(decoder: model.Decoder, + x_discriminator: model.XDiscriminator, optimizer: tf.train.Optimizer, - targets: tf.Tensor, global_step: tf.Variable, + targets: tf.Tensor, + global_step: tf.Variable, global_step_decoder: tf.Variable, z_generator: Callable[[], tf.Variable]) -> tf.Tensor: """ @@ -504,10 +524,13 @@ def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscrimi return _decoder_train_loss -def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder, +def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, + encoder: model.Encoder, optimizer: tf.train.Optimizer, - inputs: tf.Tensor, targets_real: tf.Tensor, - targets_fake: tf.Tensor, global_step: tf.Variable, + inputs: tf.Tensor, + targets_real: tf.Tensor, + targets_fake: tf.Tensor, + global_step: tf.Variable, global_step_zd: tf.Variable, z_generator: Callable[[], tf.Variable]) -> tf.Tensor: """ @@ -555,9 +578,12 @@ def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: m return _zd_train_loss -def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator, - optimizer: tf.train.Optimizer, inputs: tf.Tensor, - targets: tf.Tensor, global_step: tf.Variable, +def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, + z_discriminator: model.ZDiscriminator, + optimizer: tf.train.Optimizer, + inputs: tf.Tensor, + targets: tf.Tensor, + 
+                        global_step: tf.Variable,
                         global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
     """
     Trains the encoder and decoder jointly for one step (one batch).
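
Reviewer note: the ``Notes`` block that this patch moves below ``Args`` describes the
early-stopping mechanic only in prose. As a reading aid, here is a minimal sketch of that
bookkeeping in plain Python. It is not code from train.py: the helper name ``update_grace``,
its arguments, and the value of ``GRACE`` are assumptions; only ``TOTAL_LOSS_GRACE_CAP = 6``
is visible in the first hunk above.

    from typing import Mapping

    GRACE: int = 3                 # assumed value; the actual constant is defined in train.py
    TOTAL_LOSS_GRACE_CAP: int = 6  # matches the constant shown in the first hunk

    def update_grace(prev_losses: Mapping[str, float],
                     curr_losses: Mapping[str, float],
                     grace_left: int) -> int:
        """Hypothetical helper: remaining grace epochs after finishing one epoch."""
        # A "strike": at least one individual loss failed to decrease this epoch.
        strike = any(curr_losses[name] >= prev_losses[name] for name in curr_losses)
        if not strike:
            # Every individual loss decreased: reset the grace period.
            return GRACE
        # The total loss is the sum of all individual losses.
        total = sum(curr_losses.values())
        if total >= sum(prev_losses.values()) and total > TOTAL_LOSS_GRACE_CAP:
            # The total loss also failed to decrease and exceeds the cap:
            # consume one epoch of the grace period.
            return grace_left - 1
        return grace_left

    # Inside the epoch loop (with early_stopping=True), training stops as soon
    # as the value returned here reaches 0 at the end of an epoch.

Whether the actual implementation compares losses with strict inequality or against
best-so-far values is not visible in these hunks; the sketch only mirrors the docstring.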