Reformatted code to ensure all parameters are on their own line
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -62,9 +62,12 @@ TOTAL_LOSS_GRACE_CAP: int = 6
|
|||||||
LOG_FREQUENCY: int = 10
|
LOG_FREQUENCY: int = 10
|
||||||
|
|
||||||
|
|
||||||
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def prepare_training_data(test_fold_id: int,
|
||||||
|
inlier_classes: Sequence[int],
|
||||||
|
total_classes: int,
|
||||||
fold_prefix: str = 'data/data_fold_',
|
fold_prefix: str = 'data/data_fold_',
|
||||||
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
batch_size: int = 128,
|
||||||
|
folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
||||||
"""
|
"""
|
||||||
Prepares the MNIST training data.
|
Prepares the MNIST training data.
|
||||||
|
|
||||||
@ -126,11 +129,16 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
|||||||
return train_dataset, valid_dataset
|
return train_dataset, valid_dataset
|
||||||
|
|
||||||
|
|
||||||
def train(dataset: tf.data.Dataset, iteration: int,
|
def train(dataset: tf.data.Dataset,
|
||||||
|
iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
zsize: int = 32,
|
||||||
verbose: bool = True, early_stopping: bool = False) -> None:
|
lr: float = 0.002,
|
||||||
|
batch_size: int = 128,
|
||||||
|
train_epoch: int = 80,
|
||||||
|
verbose: bool = True,
|
||||||
|
early_stopping: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Trains AAE for given data set.
|
Trains AAE for given data set.
|
||||||
|
|
||||||
@ -141,16 +149,6 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
The loss values are provided as scalar summaries. Reconstruction and sample
|
The loss values are provided as scalar summaries. Reconstruction and sample
|
||||||
images are provided as summary images.
|
images are provided as summary images.
|
||||||
|
|
||||||
Notes:
|
|
||||||
The training stops early if for ``GRACE`` number of epochs the loss is not
|
|
||||||
decreasing. Specifically all individual losses are accounted for and any one
|
|
||||||
of those not decreasing triggers a ``strike``. If the total loss, which is
|
|
||||||
a sum of all individual losses, is also not decreasing and has a total
|
|
||||||
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
|
|
||||||
decreased. If in any epoch afterwards all losses are decreasing the grace
|
|
||||||
period is reset to ``GRACE``. Lastly the training loop will be stopped early
|
|
||||||
if the grace counter reaches ``0`` at the end of an epoch.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: train dataset
|
dataset: train dataset
|
||||||
iteration: identifier for the current training run
|
iteration: identifier for the current training run
|
||||||
@ -162,6 +160,16 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
train_epoch: number of epochs to train (default: 80)
|
train_epoch: number of epochs to train (default: 80)
|
||||||
verbose: if True prints train progress info to console (default: True)
|
verbose: if True prints train progress info to console (default: True)
|
||||||
early_stopping: if True the early stopping mechanic is enabled (default: False)
|
early_stopping: if True the early stopping mechanic is enabled (default: False)
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The training stops early if for ``GRACE`` number of epochs the loss is not
|
||||||
|
decreasing. Specifically all individual losses are accounted for and any one
|
||||||
|
of those not decreasing triggers a ``strike``. If the total loss, which is
|
||||||
|
a sum of all individual losses, is also not decreasing and has a total
|
||||||
|
value of more than ``TOTAL_LOSS_GRACE_CAP``, the counter for the remaining grace period is
|
||||||
|
decreased. If in any epoch afterwards all losses are decreasing the grace
|
||||||
|
period is reset to ``GRACE``. Lastly the training loop will be stopped early
|
||||||
|
if the grace counter reaches ``0`` at the end of an epoch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# non-preserved tensors
|
# non-preserved tensors
|
||||||
@ -308,17 +316,24 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
|
|
||||||
def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor,
|
def _train_one_epoch(epoch: int,
|
||||||
|
dataset: tf.data.Dataset,
|
||||||
|
targets_real: tf.Tensor,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
|
targets_fake: tf.Tensor,
|
||||||
|
z_generator: Callable[[], tf.Variable],
|
||||||
learning_rate_var: tf.Variable,
|
learning_rate_var: tf.Variable,
|
||||||
decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator,
|
decoder: model.Decoder,
|
||||||
z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
|
encoder: model.Encoder,
|
||||||
|
x_discriminator: model.XDiscriminator,
|
||||||
|
z_discriminator: model.ZDiscriminator,
|
||||||
|
decoder_optimizer: tf.train.Optimizer,
|
||||||
x_discriminator_optimizer: tf.train.Optimizer,
|
x_discriminator_optimizer: tf.train.Optimizer,
|
||||||
z_discriminator_optimizer: tf.train.Optimizer,
|
z_discriminator_optimizer: tf.train.Optimizer,
|
||||||
enc_dec_optimizer: tf.train.Optimizer,
|
enc_dec_optimizer: tf.train.Optimizer,
|
||||||
global_step: tf.Variable,
|
global_step: tf.Variable,
|
||||||
global_step_xd: tf.Variable, global_step_zd: tf.Variable,
|
global_step_xd: tf.Variable,
|
||||||
|
global_step_zd: tf.Variable,
|
||||||
global_step_decoder: tf.Variable,
|
global_step_decoder: tf.Variable,
|
||||||
global_step_enc_dec: tf.Variable,
|
global_step_enc_dec: tf.Variable,
|
||||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||||
@ -415,10 +430,13 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder,
|
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator,
|
||||||
|
decoder: model.Decoder,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_real: tf.Tensor,
|
||||||
|
targets_fake: tf.Tensor,
|
||||||
|
global_step: tf.Variable,
|
||||||
global_step_xd: tf.Variable,
|
global_step_xd: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -465,9 +483,11 @@ def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: m
|
|||||||
return _xd_train_loss
|
return _xd_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator,
|
def _train_decoder_step(decoder: model.Decoder,
|
||||||
|
x_discriminator: model.XDiscriminator,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
targets: tf.Tensor, global_step: tf.Variable,
|
targets: tf.Tensor,
|
||||||
|
global_step: tf.Variable,
|
||||||
global_step_decoder: tf.Variable,
|
global_step_decoder: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -504,10 +524,13 @@ def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscrimi
|
|||||||
return _decoder_train_loss
|
return _decoder_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder,
|
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator,
|
||||||
|
encoder: model.Encoder,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_real: tf.Tensor,
|
||||||
|
targets_fake: tf.Tensor,
|
||||||
|
global_step: tf.Variable,
|
||||||
global_step_zd: tf.Variable,
|
global_step_zd: tf.Variable,
|
||||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -555,9 +578,12 @@ def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: m
|
|||||||
return _zd_train_loss
|
return _zd_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator,
|
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder,
|
||||||
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
z_discriminator: model.ZDiscriminator,
|
||||||
targets: tf.Tensor, global_step: tf.Variable,
|
optimizer: tf.train.Optimizer,
|
||||||
|
inputs: tf.Tensor,
|
||||||
|
targets: tf.Tensor,
|
||||||
|
global_step: tf.Variable,
|
||||||
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||||
"""
|
"""
|
||||||
Trains the encoder and decoder jointly for one step (one batch).
|
Trains the encoder and decoder jointly for one step (one batch).
|
||||||
|
|||||||
Reference in New Issue
Block a user