Reformatted code so that each parameter is on its own line
Signed-off-by: Jim Martens <github@2martens.de>
@@ -62,9 +62,12 @@ TOTAL_LOSS_GRACE_CAP: int = 6
 LOG_FREQUENCY: int = 10
 
 
-def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
+def prepare_training_data(test_fold_id: int,
+                          inlier_classes: Sequence[int],
+                          total_classes: int,
                           fold_prefix: str = 'data/data_fold_',
-                          batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
+                          batch_size: int = 128,
+                          folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
     Prepares the MNIST training data.
 
@@ -126,11 +129,16 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
     return train_dataset, valid_dataset
 
 
-def train(dataset: tf.data.Dataset, iteration: int,
+def train(dataset: tf.data.Dataset,
+          iteration: int,
           weights_prefix: str,
-          channels: int = 1, zsize: int = 32, lr: float = 0.002,
-          batch_size: int = 128, train_epoch: int = 80,
-          verbose: bool = True, early_stopping: bool = False) -> None:
+          channels: int = 1,
+          zsize: int = 32,
+          lr: float = 0.002,
+          batch_size: int = 128,
+          train_epoch: int = 80,
+          verbose: bool = True,
+          early_stopping: bool = False) -> None:
     """
     Trains AAE for given data set.
 
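For orientation, the two public entry points above chain together. A hypothetical call site follows; the parameter names and defaults come from the signatures in this diff, while the concrete values (fold id, class list, weights prefix) are invented for illustration:

    # hypothetical usage, not part of the commit
    train_data, valid_data = prepare_training_data(test_fold_id=0,
                                                   inlier_classes=[0, 1, 2, 3],
                                                   total_classes=10)
    train(train_data,
          iteration=0,
          weights_prefix='weights/aae_',  # invented prefix
          early_stopping=True)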
@@ -141,16 +149,6 @@ def train(dataset: tf.data.Dataset, iteration: int,
     The loss values are provided as scalar summaries. Reconstruction and sample
     images are provided as summary images.
 
-    Notes:
-        The training stops early if for ``GRACE`` number of epochs the loss is not
-        decreasing. Specifically all individual losses are accounted for and any one
-        of those not decreasing triggers a ``strike``. If the total loss, which is
-        a sum of all individual losses, is also not decreasing and has a total
-        value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
-        decreased. If in any epoch afterwards all losses are decreasing the grace
-        period is reset to ``GRACE``. Lastly the training loop will be stopped early
-        if the grace counter reaches ``0`` at the end of an epoch.
-
     Args:
         dataset: train dataset
         iteration: identifier for the current training run
@@ -162,6 +160,16 @@ def train(dataset: tf.data.Dataset, iteration: int,
         train_epoch: number of epochs to train (default: 80)
         verbose: if True prints train progress info to console (default: True)
         early_stopping: if True the early stopping mechanic is enabled (default: False)
+
+    Notes:
+        The training stops early if for ``GRACE`` number of epochs the loss is not
+        decreasing. Specifically all individual losses are accounted for and any one
+        of those not decreasing triggers a ``strike``. If the total loss, which is
+        a sum of all individual losses, is also not decreasing and has a total
+        value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
+        decreased. If in any epoch afterwards all losses are decreasing the grace
+        period is reset to ``GRACE``. Lastly the training loop will be stopped early
+        if the grace counter reaches ``0`` at the end of an epoch.
     """
 
     # non-preserved tensors
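The Notes block that moved below the Args section describes the grace mechanic only in prose. The following minimal sketch replays that bookkeeping over a list of per-epoch loss dicts; ``TOTAL_LOSS_GRACE_CAP = 6`` is the module constant visible in the first hunk, while the actual value of ``GRACE`` does not appear in this diff and is a placeholder, and the loop shape is illustrative rather than the commit's code:

    def early_stopping_sketch(loss_history):
        GRACE = 3                 # placeholder, real value not shown in this diff
        TOTAL_LOSS_GRACE_CAP = 6  # module constant from the first hunk
        grace, previous = GRACE, None
        for epoch, losses in enumerate(loss_history):
            total = sum(losses.values())
            if previous is not None:
                if all(losses[k] < previous[k] for k in losses):
                    grace = GRACE  # every individual loss decreased: reset
                elif total >= sum(previous.values()) and total > TOTAL_LOSS_GRACE_CAP:
                    grace -= 1     # a strike, and the total loss stalled above the cap
            previous = losses
            if grace == 0:
                return epoch       # training stops at the end of this epoch
        return len(loss_history) - 1

Note that a strike alone does not consume grace: the counter only drops when some individual loss stalls and the total loss both stalls and sits above the cap, exactly as the docstring states.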
@@ -308,17 +316,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
         checkpoint.save(checkpoint_prefix)
 
 
-def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor,
+def _train_one_epoch(epoch: int,
+                     dataset: tf.data.Dataset,
+                     targets_real: tf.Tensor,
                      verbose: bool,
-                     targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
+                     targets_fake: tf.Tensor,
+                     z_generator: Callable[[], tf.Variable],
                      learning_rate_var: tf.Variable,
-                     decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator,
-                     z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
+                     decoder: model.Decoder,
+                     encoder: model.Encoder,
+                     x_discriminator: model.XDiscriminator,
+                     z_discriminator: model.ZDiscriminator,
+                     decoder_optimizer: tf.train.Optimizer,
                      x_discriminator_optimizer: tf.train.Optimizer,
                      z_discriminator_optimizer: tf.train.Optimizer,
                      enc_dec_optimizer: tf.train.Optimizer,
                      global_step: tf.Variable,
-                     global_step_xd: tf.Variable, global_step_zd: tf.Variable,
+                     global_step_xd: tf.Variable,
+                     global_step_zd: tf.Variable,
                      global_step_decoder: tf.Variable,
                      global_step_enc_dec: tf.Variable,
                      epoch_var: tf.Variable) -> Dict[str, float]:
@@ -415,10 +430,13 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
     return outputs
 
 
-def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder,
+def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator,
+                               decoder: model.Decoder,
                                optimizer: tf.train.Optimizer,
-                               inputs: tf.Tensor, targets_real: tf.Tensor,
-                               targets_fake: tf.Tensor, global_step: tf.Variable,
+                               inputs: tf.Tensor,
+                               targets_real: tf.Tensor,
+                               targets_fake: tf.Tensor,
+                               global_step: tf.Variable,
                                global_step_xd: tf.Variable,
                                z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
     """
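The body of ``_train_xdiscriminator_step`` is outside this diff. For readers unfamiliar with the pattern, here is a hedged sketch of what a step with this signature typically computes in an adversarial autoencoder: score real inputs and decoded noise, penalize both against the real/fake targets, and apply gradients. The loss choice and variable wiring are assumptions (TF 1.x eager style, matching the ``tf.train.Optimizer`` types in the signature), not the commit's code:

    import tensorflow as tf

    def xdiscriminator_step_sketch(x_discriminator, decoder, optimizer, inputs,
                                   targets_real, targets_fake, global_step,
                                   global_step_xd, z_generator):
        with tf.GradientTape() as tape:
            d_real = x_discriminator(inputs)                  # scores for real images
            d_fake = x_discriminator(decoder(z_generator()))  # scores for decoded noise
            loss = (tf.losses.sigmoid_cross_entropy(targets_real, d_real)
                    + tf.losses.sigmoid_cross_entropy(targets_fake, d_fake))
        grads = tape.gradient(loss, x_discriminator.variables)
        # global_step presumably drives summary logging in the real code
        optimizer.apply_gradients(zip(grads, x_discriminator.variables),
                                  global_step=global_step_xd)
        return loss

The decoder and zdiscriminator steps below follow the same shape, with the roles of generator and discriminator swapped accordingly.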
@@ -465,9 +483,11 @@ def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: m
     return _xd_train_loss
 
 
-def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator,
+def _train_decoder_step(decoder: model.Decoder,
+                        x_discriminator: model.XDiscriminator,
                         optimizer: tf.train.Optimizer,
-                        targets: tf.Tensor, global_step: tf.Variable,
+                        targets: tf.Tensor,
+                        global_step: tf.Variable,
                         global_step_decoder: tf.Variable,
                         z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
     """
@@ -504,10 +524,13 @@ def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscrimi
     return _decoder_train_loss
 
 
-def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder,
+def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator,
+                               encoder: model.Encoder,
                                optimizer: tf.train.Optimizer,
-                               inputs: tf.Tensor, targets_real: tf.Tensor,
-                               targets_fake: tf.Tensor, global_step: tf.Variable,
+                               inputs: tf.Tensor,
+                               targets_real: tf.Tensor,
+                               targets_fake: tf.Tensor,
+                               global_step: tf.Variable,
                                global_step_zd: tf.Variable,
                                z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
     """
@@ -555,9 +578,12 @@ def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: m
     return _zd_train_loss
 
 
-def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator,
-                        optimizer: tf.train.Optimizer, inputs: tf.Tensor,
-                        targets: tf.Tensor, global_step: tf.Variable,
+def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder,
+                        z_discriminator: model.ZDiscriminator,
+                        optimizer: tf.train.Optimizer,
+                        inputs: tf.Tensor,
+                        targets: tf.Tensor,
+                        global_step: tf.Variable,
                        global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
     """
     Trains the encoder and decoder jointly for one step (one batch).
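Last, the joint encoder/decoder step. The signature (a ``z_discriminator`` argument plus a three-tensor return) suggests the standard AAE objective of a reconstruction loss combined with an adversarial term on the latent codes. The sketch below is that standard form under the same TF 1.x eager assumptions as above; the actual body and the meaning of the three returned tensors are guesses:

    def enc_dec_step_sketch(encoder, decoder, z_discriminator, optimizer,
                            inputs, targets, global_step, global_step_enc_dec):
        with tf.GradientTape() as tape:
            z = encoder(inputs)
            reconstruction = decoder(z)
            recon_loss = tf.losses.mean_squared_error(inputs, reconstruction)
            # the encoder tries to make its codes look "real" to the z-discriminator
            adv_loss = tf.losses.sigmoid_cross_entropy(targets, z_discriminator(z))
            total_loss = recon_loss + adv_loss
        variables = encoder.variables + decoder.variables
        grads = tape.gradient(total_loss, variables)
        optimizer.apply_gradients(zip(grads, variables),
                                  global_step=global_step_enc_dec)
        return recon_loss, adv_loss, total_loss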