Added simple auto-encoder training

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-25 14:31:26 +01:00
parent 0eab8cd123
commit 9a08ea3bb9

View File

@ -29,6 +29,7 @@ Attributes:
Functions: Functions:
prepare_training_data(...): prepares the mnist training data prepare_training_data(...): prepares the mnist training data
train(...): trains the AAE models train(...): trains the AAE models
train_simple(...): trains a simple auto-encoder only with reconstruction loss
Todos: Todos:
- fix early stopping - fix early stopping
@ -129,6 +130,274 @@ def prepare_training_data(test_fold_id: int,
return train_dataset, valid_dataset return train_dataset, valid_dataset
def train_simple(dataset: tf.data.Dataset,
                 iteration: int,
                 weights_prefix: str,
                 channels: int = 1,
                 zsize: int = 32,
                 lr: float = 0.002,
                 batch_size: int = 128,
                 train_epoch: int = 80,
                 verbose: bool = True,
                 early_stopping: bool = False) -> None:
    """
    Trains a simple auto-encoder for the given data set.

    Only the reconstruction loss is optimised. A checkpoint is written after
    every epoch as well as after finishing training (or stopping early). When
    this function is started with the same ``iteration`` again, the training
    tries to continue where it ended last time by restoring the latest
    checkpoint found in the weights directory.

    The learning rate is provided as scalar summary and the decoded sample
    images are provided as summary image after every epoch.

    Args:
        dataset: train dataset
        iteration: identifier for the current training run
        weights_prefix: prefix for weights directory
        channels: number of channels in input image (default: 1)
        zsize: size of the intermediary z (default: 32)
        lr: initial learning rate (default: 0.002)
        batch_size: the size of each batch (default: 128)
        train_epoch: number of epochs to train (default: 80)
        verbose: if True prints train progress info to console (default: True)
        early_stopping: if True the early stopping mechanic is enabled (default: False)

    Notes:
        The simple auto-encoder has a single loss (the joint encoder/decoder
        reconstruction loss), hence early stopping only looks at that loss.
        If the loss of an epoch is lower than the lowest loss seen so far,
        the grace period is reset to ``GRACE``. If it is not lower and it
        additionally exceeds ``TOTAL_LOSS_GRACE_CAP``, the counter for the
        remaining grace period is decreased. The training loop is stopped
        early if the grace counter reaches ``0`` at the end of an epoch.
    """
    # non-preserved tensors
    y_real = K.ones(batch_size)
    sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)

    # non-preserved python variables
    enc_dec_lowest_loss = math.inf
    grace_period = GRACE

    # checkpointed tensors and variables
    checkpointables = {
        'learning_rate_var': K.variable(lr),
    }
    checkpointables.update({
        # get models
        'encoder': model.Encoder(zsize),
        'decoder': model.Decoder(channels),
        # define optimizers
        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
                                                    beta1=0.5, beta2=0.999),
        # global step counter
        'epoch_var': K.variable(-1, dtype=tf.int64),
        'global_step': tf.train.get_or_create_global_step(),
        'global_step_enc_dec': K.variable(0, dtype=tf.int64),
    })

    # checkpoint
    checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    checkpoint = tf.train.Checkpoint(**checkpointables)
    checkpoint.restore(latest_checkpoint)

    def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int:
        return int(epoch_var)

    def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
        # decodes the fixed random ``sample`` and writes it as image summary
        resultsample = decoder(sample).cpu()
        grid = util.prepare_image(resultsample)
        summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
                             max_images=1, step=global_step)

    # epoch_var is -1 unless a restored checkpoint stored a finished epoch
    last_epoch = _get_last_epoch(**checkpointables)
    previous_epochs = 0
    if last_epoch != -1:
        previous_epochs = last_epoch + 1

    with summary_ops_v2.always_record_summaries():
        summary_ops_v2.scalar(name='learning_rate', tensor=checkpointables['learning_rate_var'],
                              step=checkpointables['global_step'])

    for epoch in range(train_epoch - previous_epochs):
        _epoch = epoch + previous_epochs
        outputs = _train_one_epoch_simple(_epoch, dataset, targets_real=y_real,
                                          verbose=verbose,
                                          **checkpointables)

        if verbose:
            print((
                f"[{_epoch + 1:d}/{train_epoch:d}] - "
                f"train time: {outputs['per_epoch_time']:.2f}, "
                f"Decoder loss: {outputs['decoder_loss']:.3f}, "
                f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
                f"Encoder loss: {outputs['encoder_loss']:.3f}"
            ))

        # save sample image summary
        with summary_ops_v2.always_record_summaries():
            _save_sample(**checkpointables)

        # save weights at end of epoch
        checkpoint.save(checkpoint_prefix)

        # check for improvements in error reduction - otherwise early stopping
        if early_stopping:
            # the reconstruction loss is the only loss trained here, hence it
            # is the individual loss and the total loss at the same time
            enc_dec_loss = float(outputs['enc_dec_loss'])
            if enc_dec_loss < enc_dec_lowest_loss:
                enc_dec_lowest_loss = enc_dec_loss
                grace_period = GRACE
            elif enc_dec_loss > TOTAL_LOSS_GRACE_CAP:
                grace_period -= 1
            # a stagnating loss at or below the cap neither resets nor
            # reduces the grace period

            if grace_period == 0:
                break

    if verbose:
        if grace_period > 0:
            print("Training finish!... save model weights")
        elif grace_period == 0:
            print("Training stopped early!... save model weights")

    # save trained models
    checkpoint.save(checkpoint_prefix)
def _train_one_epoch_simple(epoch: int,
                            dataset: tf.data.Dataset,
                            targets_real: tf.Tensor,
                            verbose: bool,
                            learning_rate_var: tf.Variable,
                            decoder: model.Decoder,
                            encoder: model.Encoder,
                            enc_dec_optimizer: tf.train.Optimizer,
                            global_step: tf.Variable,
                            global_step_enc_dec: tf.Variable,
                            epoch_var: tf.Variable) -> Dict[str, float]:
    """
    Trains the simple auto-encoder for one epoch.

    Stores ``epoch`` in ``epoch_var``, divides the learning rate by four on
    every 30th epoch, runs one joint encoder/decoder step per batch and
    increments ``global_step`` after every batch. Whenever ``global_step`` is
    a multiple of ``LOG_FREQUENCY`` a reconstruction image summary is written.

    Args:
        epoch: zero-based number of the epoch to train
        dataset: train dataset; iterated as pairs of which only the first
            element is used
        targets_real: not used by this function
        verbose: if True prints a message when the learning rate changes
        learning_rate_var: variable holding the current learning rate
        decoder: instance of decoder model
        encoder: instance of encoder model
        enc_dec_optimizer: optimizer for the joint encoder/decoder step
        global_step: the global step variable
        global_step_enc_dec: global step variable for enc_dec
        epoch_var: variable that receives the current epoch number

    Returns:
        dictionary with the keys ``decoder_loss``, ``encoder_loss``,
        ``enc_dec_loss`` and ``per_epoch_time``

    Notes:
        Only the ``enc_dec_loss`` metric is updated inside this function; the
        ``encoder_loss`` and ``decoder_loss`` metrics never receive a value
        here, so their results carry no information about the training.
    """
    with summary_ops_v2.always_record_summaries():
        epoch_var.assign(epoch)
        started_at = time.time()

        # running means of the epoch
        encoder_mean = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
        decoder_mean = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
        enc_dec_mean = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)

        # learning rate decay on every 30th epoch
        decay_due = (epoch + 1) % 30 == 0
        if decay_due:
            learning_rate_var.assign(learning_rate_var.value() / 4)
            summary_ops_v2.scalar(name='learning_rate', tensor=learning_rate_var,
                                  step=global_step)
            if verbose:
                print("learning rate change!")

        for batch, _ in dataset:
            step_result = _train_enc_dec_step_simple(encoder=encoder,
                                                     decoder=decoder,
                                                     optimizer=enc_dec_optimizer,
                                                     inputs=batch,
                                                     global_step_enc_dec=global_step_enc_dec,
                                                     global_step=global_step)
            batch_loss, batch_decoded = step_result
            enc_dec_mean(batch_loss)

            if int(global_step % LOG_FREQUENCY) == 0:
                # first 64 inputs followed by their 64 reconstructions
                side_by_side = K.concatenate([batch[:64], batch_decoded[:64]], axis=0)
                image_grid = util.prepare_image(side_by_side.cpu(), nrow=64)
                summary_ops_v2.image(name='reconstruction',
                                     tensor=K.expand_dims(image_grid, axis=0), max_images=1,
                                     step=global_step)

            global_step.assign_add(1)

        elapsed = time.time() - started_at

        # final losses of epoch
        return {
            'decoder_loss': decoder_mean.result(False),
            'encoder_loss': encoder_mean.result(False),
            'enc_dec_loss': enc_dec_mean.result(False),
            'per_epoch_time': elapsed,
        }
def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
                               optimizer: tf.train.Optimizer,
                               inputs: tf.Tensor,
                               global_step: tf.Variable,
                               global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Trains the encoder and decoder jointly for one step (one batch).

    The inputs are encoded and decoded again; the log loss between inputs and
    reconstruction is minimised with respect to the trainable variables of
    both models. Whenever ``global_step`` is a multiple of ``LOG_FREQUENCY``
    the loss is written as scalar summary and the l2-normalized gradients and
    variables are written as histogram summaries.

    Args:
        encoder: instance of encoder model
        decoder: instance of decoder model
        optimizer: instance of chosen optimizer
        inputs: inputs from dataset
        global_step: the global step variable (used as summary step)
        global_step_enc_dec: global step variable for enc_dec (passed to the optimizer)

    Returns:
        tuple of reconstruction loss, reconstructed input
    """
    with tf.GradientTape() as tape:
        x_decoded = decoder(encoder(inputs))
        reconstruction_loss = tf.losses.log_loss(inputs, x_decoded)

    # collected after the forward pass, exactly where the original reads them
    trainables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(reconstruction_loss, trainables)
    grads_and_vars = list(zip(gradients, trainables))

    if int(global_step % LOG_FREQUENCY) == 0:
        # the train loss is identical to the reconstruction loss, both names are logged
        summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
                              step=global_step)
        summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=reconstruction_loss,
                              step=global_step)
        for gradient, var in grads_and_vars:
            summary_ops_v2.histogram(name='gradients/' + var.name,
                                     tensor=tf.math.l2_normalize(gradient),
                                     step=global_step)
            summary_ops_v2.histogram(name='variables/' + var.name,
                                     tensor=tf.math.l2_normalize(var),
                                     step=global_step)

    optimizer.apply_gradients(grads_and_vars, global_step=global_step_enc_dec)

    return reconstruction_loss, x_decoded
def train(dataset: tf.data.Dataset, def train(dataset: tf.data.Dataset,
iteration: int, iteration: int,
weights_prefix: str, weights_prefix: str,