From 9a08ea3bb995d780ec6fefec12e2a2ba3b90f892 Mon Sep 17 00:00:00 2001
From: Jim Martens
Date: Mon, 25 Mar 2019 14:31:26 +0100
Subject: [PATCH] Added simple auto-encoder training

Signed-off-by: Jim Martens
---
Review note: the added code was revised during review. The diffstat and the
hunk sizes below reflect the revision; the blob hash on the "index" line
still refers to the original submission.

 src/twomartens/masterthesis/aae/train.py | 255 +++++++++++++++++++++++
 1 file changed, 255 insertions(+)

diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 6c2f849..06bf50e 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -29,6 +29,7 @@ Attributes:
 Functions:
     prepare_training_data(...): prepares the mnist training data
     train(...): trains the AAE models
+    train_simple(...): trains a simple auto-encoder only with reconstruction loss
 
 Todos:
     - fix early stopping
@@ -129,6 +130,260 @@ def prepare_training_data(test_fold_id: int,
     return train_dataset, valid_dataset
 
 
+def train_simple(dataset: tf.data.Dataset,
+                 iteration: int,
+                 weights_prefix: str,
+                 channels: int = 1,
+                 zsize: int = 32,
+                 lr: float = 0.002,
+                 batch_size: int = 128,
+                 train_epoch: int = 80,
+                 verbose: bool = True,
+                 early_stopping: bool = False) -> None:
+    """
+    Trains a simple auto-encoder for the given data set.
+
+    Only the reconstruction loss of encoder and decoder is optimized. This
+    function creates checkpoints after every epoch as well as after finishing
+    training (or stopping early). When starting this function with the same
+    ``iteration`` then the training will try to continue where it ended last
+    time by restoring a saved checkpoint.
+    The loss values are provided as scalar summaries. Reconstruction and sample
+    images are provided as summary images.
+
+    Args:
+        dataset: train dataset
+        iteration: identifier for the current training run
+        weights_prefix: prefix for weights directory
+        channels: number of channels in input image (default: 1)
+        zsize: size of the intermediary z (default: 32)
+        lr: initial learning rate (default: 0.002)
+        batch_size: the size of each batch (default: 128)
+        train_epoch: number of epochs to train (default: 80)
+        verbose: if True prints train progress info to console (default: True)
+        early_stopping: if True the early stopping mechanic is enabled (default: False)
+
+    Notes:
+        If ``early_stopping`` is enabled, the reconstruction loss of each epoch
+        is compared with the lowest loss seen so far during this call. A new
+        lowest loss resets the grace period to ``GRACE``. An epoch without
+        improvement whose loss is higher than ``TOTAL_LOSS_GRACE_CAP``
+        decreases the remaining grace period by one. The training loop is
+        stopped early if the grace counter reaches ``0`` at the end of an
+        epoch. The lowest loss and the grace counter are not checkpointed, a
+        resumed training starts with fresh values for both.
+    """
+
+    # non-preserved tensors
+    # y_real is only passed through to _train_one_epoch_simple, which does not
+    # use it for the reconstruction-only training
+    y_real = K.ones(batch_size)
+    sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
+
+    # non-preserved python variables
+    enc_dec_lowest_loss = math.inf
+    grace_period = GRACE
+
+    # checkpointed tensors and variables
+    checkpointables = {
+        'learning_rate_var': K.variable(lr),
+    }
+    checkpointables.update({
+        # get models
+        'encoder': model.Encoder(zsize),
+        'decoder': model.Decoder(channels),
+        # define optimizers
+        'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
+                                                    beta1=0.5, beta2=0.999),
+        # global step counter
+        'epoch_var': K.variable(-1, dtype=tf.int64),
+        'global_step': tf.train.get_or_create_global_step(),
+        'global_step_enc_dec': K.variable(0, dtype=tf.int64),
+    })
+
+    # checkpoint
+    checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
+    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
+    checkpoint = tf.train.Checkpoint(**checkpointables)
+    checkpoint.restore(latest_checkpoint)
+
+    # epoch_var is -1 for a fresh run, otherwise it holds the index of the
+    # last epoch that was checkpointed
+    previous_epochs = int(checkpointables['epoch_var']) + 1
+
+    def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
+        """Writes decoded images of the fixed ``sample`` tensor as image summary."""
+        resultsample = decoder(sample).cpu()
+        grid = util.prepare_image(resultsample)
+        summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
+                             max_images=1, step=global_step)
+
+    with summary_ops_v2.always_record_summaries():
+        summary_ops_v2.scalar(name='learning_rate', tensor=checkpointables['learning_rate_var'],
+                              step=checkpointables['global_step'])
+
+    for _epoch in range(previous_epochs, train_epoch):
+        outputs = _train_one_epoch_simple(_epoch, dataset, targets_real=y_real,
+                                          verbose=verbose,
+                                          **checkpointables)
+        enc_dec_loss = float(outputs['enc_dec_loss'])
+
+        if verbose:
+            print((
+                f"[{_epoch + 1:d}/{train_epoch:d}] - "
+                f"train time: {outputs['per_epoch_time']:.2f}, "
+                f"Encoder + Decoder loss: {enc_dec_loss:.3f}"
+            ))
+
+        # save sample image summary
+        with summary_ops_v2.always_record_summaries():
+            _save_sample(**checkpointables)
+
+        # save weights at end of epoch
+        checkpoint.save(checkpoint_prefix)
+
+        # check for improvements in the reconstruction loss - otherwise early stopping
+        if early_stopping:
+            if enc_dec_loss < enc_dec_lowest_loss:
+                enc_dec_lowest_loss = enc_dec_loss
+                grace_period = GRACE
+            elif enc_dec_loss > TOTAL_LOSS_GRACE_CAP:
+                grace_period -= 1
+
+            if grace_period == 0:
+                break
+
+    if verbose:
+        if grace_period > 0:
+            print("Training finish!... save model weights")
+        if grace_period == 0:
+            print("Training stopped early!... save model weights")
+
+    # save trained models
+    checkpoint.save(checkpoint_prefix)
+
+
+def _train_one_epoch_simple(epoch: int,
+                            dataset: tf.data.Dataset,
+                            targets_real: tf.Tensor,
+                            verbose: bool,
+                            learning_rate_var: tf.Variable,
+                            decoder: model.Decoder,
+                            encoder: model.Encoder,
+                            enc_dec_optimizer: tf.train.Optimizer,
+                            global_step: tf.Variable,
+                            global_step_enc_dec: tf.Variable,
+                            epoch_var: tf.Variable) -> Dict[str, float]:
+    """
+    Trains the simple auto-encoder for one epoch.
+
+    Args:
+        epoch: zero-based index of the epoch to train
+        dataset: train dataset, iterated as pairs of which only the first element is used
+        targets_real: not used by the reconstruction-only training
+        verbose: if True prints a message when the learning rate changes
+        learning_rate_var: variable holding the learning rate, divided by 4 every 30th epoch
+        decoder: instance of decoder model
+        encoder: instance of encoder model
+        enc_dec_optimizer: optimizer for the joint encoder and decoder training
+        global_step: the global step variable, increased by one per batch
+        global_step_enc_dec: global step variable for enc_dec
+        epoch_var: variable that stores the current epoch for checkpointing
+
+    Returns:
+        dictionary with the mean reconstruction loss of the epoch (``enc_dec_loss``)
+        and the duration of the epoch in seconds (``per_epoch_time``)
+    """
+    with summary_ops_v2.always_record_summaries():
+        epoch_var.assign(epoch)
+        epoch_start_time = time.time()
+        # define loss variables
+        enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
+
+        # update learning rate
+        if (epoch + 1) % 30 == 0:
+            learning_rate_var.assign(learning_rate_var.value() / 4)
+            summary_ops_v2.scalar(name='learning_rate', tensor=learning_rate_var,
+                                  step=global_step)
+            if verbose:
+                print("learning rate change!")
+
+        for x, _ in dataset:
+            reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder,
+                                                                        decoder=decoder,
+                                                                        optimizer=enc_dec_optimizer,
+                                                                        inputs=x,
+                                                                        global_step_enc_dec=global_step_enc_dec,
+                                                                        global_step=global_step)
+            enc_dec_loss_avg(reconstruction_loss)
+
+            if int(global_step % LOG_FREQUENCY) == 0:
+                comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
+                grid = util.prepare_image(comparison.cpu(), nrow=64)
+                summary_ops_v2.image(name='reconstruction',
+                                     tensor=K.expand_dims(grid, axis=0), max_images=1,
+                                     step=global_step)
+            global_step.assign_add(1)
+
+        epoch_end_time = time.time()
+        per_epoch_time = epoch_end_time - epoch_start_time
+
+        # final losses of epoch
+        outputs = {
+            'enc_dec_loss': enc_dec_loss_avg.result(False),
+            'per_epoch_time': per_epoch_time,
+        }
+
+    return outputs
+
+
+def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
+                               optimizer: tf.train.Optimizer,
+                               inputs: tf.Tensor,
+                               global_step: tf.Variable,
+                               global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]:
+    """
+    Trains the encoder and decoder jointly for one step (one batch).
+
+    Args:
+        encoder: instance of encoder model
+        decoder: instance of decoder model
+        optimizer: instance of chosen optimizer
+        inputs: inputs from dataset
+        global_step: the global step variable
+        global_step_enc_dec: global step variable for enc_dec
+
+    Returns:
+        tuple of reconstruction loss, reconstructed input
+    """
+    with tf.GradientTape() as tape:
+        z = encoder(inputs)
+        x_decoded = decoder(z)
+
+        reconstruction_loss = tf.losses.log_loss(inputs, x_decoded)
+        _enc_dec_train_loss = reconstruction_loss
+
+    # the variable list is built once and shared by gradients, summaries and optimizer
+    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
+    enc_dec_grads = tape.gradient(_enc_dec_train_loss, trainable_variables)
+    if int(global_step % LOG_FREQUENCY) == 0:
+        summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
+                              step=global_step)
+        summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss,
+                              step=global_step)
+        for grad, variable in zip(enc_dec_grads, trainable_variables):
+            summary_ops_v2.histogram(name='gradients/' + variable.name, tensor=tf.math.l2_normalize(grad),
+                                     step=global_step)
+            summary_ops_v2.histogram(name='variables/' + variable.name, tensor=tf.math.l2_normalize(variable),
+                                     step=global_step)
+    optimizer.apply_gradients(zip(enc_dec_grads, trainable_variables),
+                              global_step=global_step_enc_dec)
+
+    return reconstruction_loss, x_decoded
+
+
 def train(dataset: tf.data.Dataset,
           iteration: int,
           weights_prefix: str,