From 2d1fee8048243f92cdcb64ea9ee05d07f88a22f1 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 16:34:23 +0100 Subject: [PATCH] Moved training of epoch into separate function Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 334 ++++++++++++----------- 1 file changed, 169 insertions(+), 165 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 50ff5ac..9ddfa4a 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -20,7 +20,7 @@ import math import os import pickle import time -from typing import Callable, Sequence, Tuple +from typing import Callable, Dict, Sequence, Tuple import numpy as np import tensorflow as tf @@ -96,7 +96,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota return train_dataset, valid_dataset -def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, +def train(dataset: tf.data.Dataset, iteration: int, weights_prefix: str, channels: int = 1, zsize: int = 32, lr: float = 0.002, batch_size: int = 128, train_epoch: int = 80, @@ -106,7 +106,6 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, :param dataset: train dataset :param iteration: identifier for the current training run - :param result_prefix: prefix for result images :param weights_prefix: prefix for weights directory :param channels: number of channels in input image (default: 1) :param zsize: size of the intermediary z (default: 32) @@ -116,31 +115,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, :param verbose: if True prints train progress info to console (default: True) """ - # get models - encoder = Encoder(zsize) - decoder = Decoder(channels) - z_discriminator = ZDiscriminator() - x_discriminator = XDiscriminator() - - # define optimizers - learning_rate_var = k.variable(lr) - decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999) - enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999) - z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999) - x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999) - - # train + # non-preserved tensors y_real = k.ones(batch_size) y_fake = k.zeros(batch_size) sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) - + # z generator function z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) - global_step_decoder = k.variable(0, dtype=tf.int64) - global_step_enc_dec = k.variable(0, dtype=tf.int64) - global_step_xd = k.variable(0, dtype=tf.int64) - global_step_zd = k.variable(0, dtype=tf.int64) - + # non-preserved python variables encoder_lowest_loss = math.inf decoder_lowest_loss = math.inf enc_dec_lowest_loss = math.inf @@ -149,148 +131,62 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, total_lowest_loss = math.inf grace_period = GRACE + # checkpointed tensors and variables + checkpointables = { + 'learning_rate_var': k.variable(lr), + } + checkpointables.update({ + # get models + 'encoder': Encoder(zsize), + 'decoder': Decoder(channels), + 'z_discriminator': ZDiscriminator(), + 'x_discriminator': XDiscriminator(), + # define optimizers + 'decoder_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999), + 'enc_dec_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999), + 'z_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], + beta1=0.5, beta2=0.999), + 'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], + beta1=0.5, beta2=0.999), + # global step counter + 'global_step_decoder': k.variable(0, dtype=tf.int64), + 'global_step_enc_dec': k.variable(0, dtype=tf.int64), + 'global_step_xd': k.variable(0, dtype=tf.int64), + 'global_step_zd': k.variable(0, dtype=tf.int64), + }) + + # checkpoint checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/') os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) - checkpoint = tf.train.Checkpoint(encoder=encoder, - decoder=decoder, - z_discriminator=z_discriminator, - x_discriminator=x_discriminator, - decoder_optimizer=decoder_optimizer, - z_discriminator_optimizer=z_discriminator_optimizer, - x_discriminator_optimizer=x_discriminator_optimizer, - enc_dec_optimizer=enc_dec_optimizer, - global_step_decoder=global_step_decoder, - global_step_enc_dec=global_step_enc_dec, - global_step_xd=global_step_xd, - global_step_zd=global_step_zd, - learning_rate_var=learning_rate_var) - if latest_checkpoint is not None: - # if there is a checkpoint in the current training iteration, proceed from there - checkpoint.restore(latest_checkpoint) + checkpoint = tf.train.Checkpoint(**checkpointables) + checkpoint.restore(latest_checkpoint) for epoch in range(train_epoch): - # define loss variables - encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) - decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) - enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) - zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32) - xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32) + outputs = _train_one_epoch(epoch, dataset, targets_real=y_real, + targets_fake=y_fake, z_generator= z_generator, + verbose=verbose, + **checkpointables) - epoch_start_time = time.time() - - # update learning rate - if (epoch + 1) % 30 == 0: - learning_rate_var.assign(learning_rate_var.value() / 4) - if verbose: - print("learning rate change!") - - log_frequency = 10 - batch_iteration = k.variable(0, dtype=tf.int64) - for x, _ in dataset: - # x discriminator - _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator, - decoder=decoder, - optimizer=x_discriminator_optimizer, - inputs=x, - targets_real=y_real, - targets_fake=y_fake, - global_step=global_step_xd, - z_generator=z_generator) - xd_loss_avg(_xd_train_loss) - - # -------- - # decoder - _decoder_train_loss = train_decoder_step(decoder=decoder, - x_discriminator=x_discriminator, - optimizer=decoder_optimizer, - targets=y_real, - global_step=global_step_decoder, - z_generator=z_generator) - decoder_loss_avg(_decoder_train_loss) - - # --------- - # z discriminator - _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator, - encoder=encoder, - optimizer=z_discriminator_optimizer, - inputs=x, - targets_real=y_real, - targets_fake=y_fake, - global_step=global_step_zd, - z_generator=z_generator) - zd_loss_avg(_zd_train_loss) - - # ----------- - # encoder + decoder - encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder, - decoder=decoder, - z_discriminator=z_discriminator, - optimizer=enc_dec_optimizer, - inputs=x, - targets=y_real, - global_step=global_step_enc_dec) - enc_dec_loss_avg(reconstruction_loss) - encoder_loss_avg(encoder_loss) - - if int(global_step_decoder % log_frequency) == 0: - # log the losses every log frequency batches - summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec) - summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder) - summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec) - summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd) - summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd) - - if int(batch_iteration) == 0: - directory = 'results' + str(inlier_classes[0]) - if not os.path.exists(directory): - os.makedirs(directory) - comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) - grid = prepare_image(comparison.cpu(), nrow=64) - summary_ops_v2.image(name='reconstruction_' + str(epoch), - tensor=k.expand_dims(grid, axis=0), max_images=1, - step=global_step_decoder) - from PIL import Image - filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png') - ndarr = grid.cpu().numpy() - im = Image.fromarray(ndarr) - im.save(filename) - - batch_iteration.assign_add(1) - - epoch_end_time = time.time() - per_epoch_time = epoch_end_time - epoch_start_time - - # final losses of epoch - decoder_loss = decoder_loss_avg.result(False) - encoder_loss = encoder_loss_avg.result(False) - enc_dec_loss = enc_dec_loss_avg.result(False) - xd_loss = xd_loss_avg.result(False) - zd_loss = zd_loss_avg.result(False) if verbose: print(( f"[{epoch + 1:d}/{train_epoch:d}] - " - f"train time: {per_epoch_time:.2f}, " - f"Decoder loss: {decoder_loss:.3f}, " - f"X Discriminator loss: {xd_loss:.3f}, " - f"Z Discriminator loss: {zd_loss:.3f}, " - f"Encoder + Decoder loss: {enc_dec_loss:.3f}, " - f"Encoder loss: {encoder_loss:.3f}" + f"train time: {outputs['per_epoch_time']:.2f}, " + f"Decoder loss: {outputs['decoder_loss']:.3f}, " + f"X Discriminator loss: {outputs['xd_loss']:.3f}, " + f"Z Discriminator loss: {outputs['zd_loss']:.3f}, " + f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, " + f"Encoder loss: {outputs['encoder_loss']:.3f}" )) - # save sample image - resultsample = decoder(sample).cpu() - directory = 'results' + str(inlier_classes[0]) - os.makedirs(directory, exist_ok=True) - grid = prepare_image(resultsample) - summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), - max_images=1, step=global_step_decoder) - from PIL import Image - filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png') - ndarr = grid.cpu().numpy() - im = Image.fromarray(ndarr) - im.save(filename) + # save sample image summary + def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None: + resultsample = decoder(sample).cpu() + grid = prepare_image(resultsample) + summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), + max_images=1, step=global_step_decoder) + _save_sample(**checkpointables) # save weights at end of epoch checkpoint.save(checkpoint_prefix) @@ -298,29 +194,30 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, # check for improvements in error reduction - otherwise early stopping strike = False total_strike = False - total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss + total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \ + outputs['xd_loss'] + outputs['zd_loss'] if total_loss < total_lowest_loss: total_lowest_loss = total_loss elif total_loss > 6: total_strike = True - if encoder_loss < encoder_lowest_loss: - encoder_lowest_loss = encoder_loss + if outputs['encoder_loss'] < encoder_lowest_loss: + encoder_lowest_loss = outputs['encoder_loss'] else: strike = True - if decoder_loss < decoder_lowest_loss: - decoder_lowest_loss = decoder_loss + if outputs['decoder_loss'] < decoder_lowest_loss: + decoder_lowest_loss = outputs['decoder_loss'] else: strike = True - if enc_dec_loss < enc_dec_lowest_loss: - enc_dec_lowest_loss = enc_dec_loss + if outputs['enc_dec_loss'] < enc_dec_lowest_loss: + enc_dec_lowest_loss = outputs['enc_dec_loss'] else: strike = True - if xd_loss < xd_lowest_loss: - xd_lowest_loss = xd_loss + if outputs['xd_loss'] < xd_lowest_loss: + xd_lowest_loss = outputs['xd_loss'] else: strike = True - if zd_loss < zd_lowest_loss: - zd_lowest_loss = zd_loss + if outputs['zd_loss'] < zd_lowest_loss: + zd_lowest_loss = outputs['zd_loss'] else: strike = True @@ -344,6 +241,114 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, checkpoint.save(checkpoint_prefix) +def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor, + verbose: bool, + targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable], + learning_rate_var: tf.Variable, + decoder: Decoder, encoder: Encoder, x_discriminator: XDiscriminator, + z_discriminator: ZDiscriminator, decoder_optimizer: tf.train.Optimizer, + x_discriminator_optimizer: tf.train.Optimizer, + z_discriminator_optimizer: tf.train.Optimizer, + enc_dec_optimizer: tf.train.Optimizer, + global_step_xd: tf.Variable, global_step_zd: tf.Variable, + global_step_decoder: tf.Variable, + global_step_enc_dec: tf.Variable) -> Dict[str, float]: + + epoch_start_time = time.time() + # define loss variables + encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) + decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) + enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) + zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32) + xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32) + + # update learning rate + if (epoch + 1) % 30 == 0: + learning_rate_var.assign(learning_rate_var.value() / 4) + if verbose: + print("learning rate change!") + + log_frequency = 10 + batch_iteration = k.variable(0, dtype=tf.int64) + for x, _ in dataset: + # x discriminator + _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator, + decoder=decoder, + optimizer=x_discriminator_optimizer, + inputs=x, + targets_real=targets_real, + targets_fake=targets_fake, + global_step=global_step_xd, + z_generator=z_generator) + xd_loss_avg(_xd_train_loss) + + # -------- + # decoder + _decoder_train_loss = train_decoder_step(decoder=decoder, + x_discriminator=x_discriminator, + optimizer=decoder_optimizer, + targets=targets_real, + global_step=global_step_decoder, + z_generator=z_generator) + decoder_loss_avg(_decoder_train_loss) + + # --------- + # z discriminator + _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator, + encoder=encoder, + optimizer=z_discriminator_optimizer, + inputs=x, + targets_real=targets_real, + targets_fake=targets_fake, + global_step=global_step_zd, + z_generator=z_generator) + zd_loss_avg(_zd_train_loss) + + # ----------- + # encoder + decoder + encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder, + decoder=decoder, + z_discriminator=z_discriminator, + optimizer=enc_dec_optimizer, + inputs=x, + targets=targets_real, + global_step=global_step_enc_dec) + enc_dec_loss_avg(reconstruction_loss) + encoder_loss_avg(encoder_loss) + + if int(global_step_decoder % log_frequency) == 0: + # log the losses every log frequency batches + summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec) + summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder) + summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec) + summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd) + summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd) + + if int(batch_iteration) == 0: + comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) + grid = prepare_image(comparison.cpu(), nrow=64) + summary_ops_v2.image(name='reconstruction_' + str(epoch), + tensor=k.expand_dims(grid, axis=0), max_images=1, + step=global_step_decoder) + + batch_iteration.assign_add(1) + + epoch_end_time = time.time() + per_epoch_time = epoch_end_time - epoch_start_time + + # final losses of epoch + outputs = { + 'decoder_loss': decoder_loss_avg.result(False), + 'encoder_loss': encoder_loss_avg.result(False), + 'enc_dec_loss': enc_dec_loss_avg.result(False), + 'xd_loss': xd_loss_avg.result(False), + 'zd_loss': zd_loss_avg.result(False), + 'per_epoch_time': per_epoch_time, + } + + return outputs + + def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, optimizer: tf.train.Optimizer, inputs: tf.Tensor, targets_real: tf.Tensor, @@ -525,5 +530,4 @@ if __name__ == "__main__": './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): train(dataset=train_dataset, iteration=iteration, - result_prefix='results' + str(inlier_classes[0]) + '/', weights_prefix='weights/' + str(inlier_classes[0]) + '/')