Reformatted code to ensure all parameters are on their own line

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 22:42:49 +01:00
parent 9909b5fdfc
commit 416cb4c7b5

View File

@ -62,9 +62,12 @@ TOTAL_LOSS_GRACE_CAP: int = 6
LOG_FREQUENCY: int = 10
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
def prepare_training_data(test_fold_id: int,
inlier_classes: Sequence[int],
total_classes: int,
fold_prefix: str = 'data/data_fold_',
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
batch_size: int = 128,
folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
"""
Prepares the MNIST training data.
@ -126,11 +129,16 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
return train_dataset, valid_dataset
def train(dataset: tf.data.Dataset, iteration: int,
def train(dataset: tf.data.Dataset,
iteration: int,
weights_prefix: str,
channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80,
verbose: bool = True, early_stopping: bool = False) -> None:
channels: int = 1,
zsize: int = 32,
lr: float = 0.002,
batch_size: int = 128,
train_epoch: int = 80,
verbose: bool = True,
early_stopping: bool = False) -> None:
"""
Trains AAE for given data set.
@ -141,16 +149,6 @@ def train(dataset: tf.data.Dataset, iteration: int,
The loss values are provided as scalar summaries. Reconstruction and sample
images are provided as summary images.
Notes:
The training stops early if for ``GRACE`` number of epochs the loss is not
decreasing. Specifically all individual losses are accounted for and any one
of those not decreasing triggers a ``strike``. If the total loss, which is
a sum of all individual losses, is also not decreasing and has a total
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
decreased. If in any epoch afterwards all losses are decreasing the grace
period is reset to ``GRACE``. Lastly the training loop will be stopped early
if the grace counter reaches ``0`` at the end of an epoch.
Args:
dataset: train dataset
iteration: identifier for the current training run
@ -162,6 +160,16 @@ def train(dataset: tf.data.Dataset, iteration: int,
train_epoch: number of epochs to train (default: 80)
verbose: if True prints train progress info to console (default: True)
early_stopping: if True the early stopping mechanic is enabled (default: False)
Notes:
The training stops early if for ``GRACE`` number of epochs the loss is not
decreasing. Specifically all individual losses are accounted for and any one
of those not decreasing triggers a ``strike``. If the total loss, which is
a sum of all individual losses, is also not decreasing and has a total
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
decreased. If in any epoch afterwards all losses are decreasing the grace
period is reset to ``GRACE``. Lastly the training loop will be stopped early
if the grace counter reaches ``0`` at the end of an epoch.
"""
# non-preserved tensors
@ -308,17 +316,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
checkpoint.save(checkpoint_prefix)
def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor,
def _train_one_epoch(epoch: int,
dataset: tf.data.Dataset,
targets_real: tf.Tensor,
verbose: bool,
targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
targets_fake: tf.Tensor,
z_generator: Callable[[], tf.Variable],
learning_rate_var: tf.Variable,
decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator,
z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
decoder: model.Decoder,
encoder: model.Encoder,
x_discriminator: model.XDiscriminator,
z_discriminator: model.ZDiscriminator,
decoder_optimizer: tf.train.Optimizer,
x_discriminator_optimizer: tf.train.Optimizer,
z_discriminator_optimizer: tf.train.Optimizer,
enc_dec_optimizer: tf.train.Optimizer,
global_step: tf.Variable,
global_step_xd: tf.Variable, global_step_zd: tf.Variable,
global_step_xd: tf.Variable,
global_step_zd: tf.Variable,
global_step_decoder: tf.Variable,
global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]:
@ -415,10 +430,13 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
return outputs
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder,
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator,
decoder: model.Decoder,
optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable,
inputs: tf.Tensor,
targets_real: tf.Tensor,
targets_fake: tf.Tensor,
global_step: tf.Variable,
global_step_xd: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
"""
@ -465,9 +483,11 @@ def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: m
return _xd_train_loss
def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator,
def _train_decoder_step(decoder: model.Decoder,
x_discriminator: model.XDiscriminator,
optimizer: tf.train.Optimizer,
targets: tf.Tensor, global_step: tf.Variable,
targets: tf.Tensor,
global_step: tf.Variable,
global_step_decoder: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
"""
@ -504,10 +524,13 @@ def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscrimi
return _decoder_train_loss
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder,
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator,
encoder: model.Encoder,
optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor,
targets_fake: tf.Tensor, global_step: tf.Variable,
inputs: tf.Tensor,
targets_real: tf.Tensor,
targets_fake: tf.Tensor,
global_step: tf.Variable,
global_step_zd: tf.Variable,
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
"""
@ -555,9 +578,12 @@ def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: m
return _zd_train_loss
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator,
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
targets: tf.Tensor, global_step: tf.Variable,
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder,
z_discriminator: model.ZDiscriminator,
optimizer: tf.train.Optimizer,
inputs: tf.Tensor,
targets: tf.Tensor,
global_step: tf.Variable,
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Trains the encoder and decoder jointly for one step (one batch).