Reformatted code to ensure all parameters are on their own line

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 22:42:49 +01:00
parent 9909b5fdfc
commit 416cb4c7b5

View File

@ -62,9 +62,12 @@ TOTAL_LOSS_GRACE_CAP: int = 6
LOG_FREQUENCY: int = 10 LOG_FREQUENCY: int = 10
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int, def prepare_training_data(test_fold_id: int,
inlier_classes: Sequence[int],
total_classes: int,
fold_prefix: str = 'data/data_fold_', fold_prefix: str = 'data/data_fold_',
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]: batch_size: int = 128,
folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
""" """
Prepares the MNIST training data. Prepares the MNIST training data.
@ -126,11 +129,16 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
return train_dataset, valid_dataset return train_dataset, valid_dataset
def train(dataset: tf.data.Dataset, iteration: int, def train(dataset: tf.data.Dataset,
iteration: int,
weights_prefix: str, weights_prefix: str,
channels: int = 1, zsize: int = 32, lr: float = 0.002, channels: int = 1,
batch_size: int = 128, train_epoch: int = 80, zsize: int = 32,
verbose: bool = True, early_stopping: bool = False) -> None: lr: float = 0.002,
batch_size: int = 128,
train_epoch: int = 80,
verbose: bool = True,
early_stopping: bool = False) -> None:
""" """
Trains AAE for given data set. Trains AAE for given data set.
@ -141,16 +149,6 @@ def train(dataset: tf.data.Dataset, iteration: int,
The loss values are provided as scalar summaries. Reconstruction and sample The loss values are provided as scalar summaries. Reconstruction and sample
images are provided as summary images. images are provided as summary images.
Notes:
The training stops early if for ``GRACE`` number of epochs the loss is not
decreasing. Specifically all individual losses are accounted for and any one
of those not decreasing triggers a ``strike``. If the total loss, which is
a sum of all individual losses, is also not decreasing and has a total
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
decreased. If in any epoch afterwards all losses are decreasing the grace
period is reset to ``GRACE``. Lastly the training loop will be stopped early
if the grace counter reaches ``0`` at the end of an epoch.
Args: Args:
dataset: train dataset dataset: train dataset
iteration: identifier for the current training run iteration: identifier for the current training run
@ -162,6 +160,16 @@ def train(dataset: tf.data.Dataset, iteration: int,
train_epoch: number of epochs to train (default: 80) train_epoch: number of epochs to train (default: 80)
verbose: if True prints train progress info to console (default: True) verbose: if True prints train progress info to console (default: True)
early_stopping: if True the early stopping mechanic is enabled (default: False) early_stopping: if True the early stopping mechanic is enabled (default: False)
Notes:
The training stops early if for ``GRACE`` number of epochs the loss is not
decreasing. Specifically all individual losses are accounted for and any one
of those not decreasing triggers a ``strike``. If the total loss, which is
a sum of all individual losses, is also not decreasing and has a total
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
decreased. If in any epoch afterwards all losses are decreasing the grace
period is reset to ``GRACE``. Lastly the training loop will be stopped early
if the grace counter reaches ``0`` at the end of an epoch.
""" """
# non-preserved tensors # non-preserved tensors
@ -308,17 +316,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
checkpoint.save(checkpoint_prefix) checkpoint.save(checkpoint_prefix)
def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor, def _train_one_epoch(epoch: int,
dataset: tf.data.Dataset,
targets_real: tf.Tensor,
verbose: bool, verbose: bool,
targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable], targets_fake: tf.Tensor,
z_generator: Callable[[], tf.Variable],
learning_rate_var: tf.Variable, learning_rate_var: tf.Variable,
decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator, decoder: model.Decoder,
z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer, encoder: model.Encoder,
x_discriminator: model.XDiscriminator,
z_discriminator: model.ZDiscriminator,
decoder_optimizer: tf.train.Optimizer,
x_discriminator_optimizer: tf.train.Optimizer, x_discriminator_optimizer: tf.train.Optimizer,
z_discriminator_optimizer: tf.train.Optimizer, z_discriminator_optimizer: tf.train.Optimizer,
enc_dec_optimizer: tf.train.Optimizer, enc_dec_optimizer: tf.train.Optimizer,
global_step: tf.Variable, global_step: tf.Variable,
global_step_xd: tf.Variable, global_step_zd: tf.Variable, global_step_xd: tf.Variable,
global_step_zd: tf.Variable,
global_step_decoder: tf.Variable, global_step_decoder: tf.Variable,
global_step_enc_dec: tf.Variable, global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]: epoch_var: tf.Variable) -> Dict[str, float]:
@ -415,10 +430,13 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
return outputs return outputs
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder, def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator,
decoder: model.Decoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor, inputs: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable, targets_real: tf.Tensor,
targets_fake: tf.Tensor,
global_step: tf.Variable,
global_step_xd: tf.Variable, global_step_xd: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor: z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
""" """
@ -465,9 +483,11 @@ def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: m
return _xd_train_loss return _xd_train_loss
def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator, def _train_decoder_step(decoder: model.Decoder,
x_discriminator: model.XDiscriminator,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
targets: tf.Tensor, global_step: tf.Variable, targets: tf.Tensor,
global_step: tf.Variable,
global_step_decoder: tf.Variable, global_step_decoder: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor: z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
""" """
@ -504,10 +524,13 @@ def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscrimi
return _decoder_train_loss return _decoder_train_loss
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder, def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator,
encoder: model.Encoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor, inputs: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable, targets_real: tf.Tensor,
targets_fake: tf.Tensor,
global_step: tf.Variable,
global_step_zd: tf.Variable, global_step_zd: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor: z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
""" """
@ -555,9 +578,12 @@ def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: m
return _zd_train_loss return _zd_train_loss
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator, def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder,
optimizer: tf.train.Optimizer, inputs: tf.Tensor, z_discriminator: model.ZDiscriminator,
targets: tf.Tensor, global_step: tf.Variable, optimizer: tf.train.Optimizer,
inputs: tf.Tensor,
targets: tf.Tensor,
global_step: tf.Variable,
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
""" """
Trains the encoder and decoder jointly for one step (one batch). Trains the encoder and decoder jointly for one step (one batch).