Moved training of epoch into separate function

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 16:34:23 +01:00
parent 6f36aa7faf
commit 2d1fee8048

View File

@ -20,7 +20,7 @@ import math
import os import os
import pickle import pickle
import time import time
from typing import Callable, Sequence, Tuple from typing import Callable, Dict, Sequence, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -96,7 +96,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
return train_dataset, valid_dataset return train_dataset, valid_dataset
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, def train(dataset: tf.data.Dataset, iteration: int,
weights_prefix: str, weights_prefix: str,
channels: int = 1, zsize: int = 32, lr: float = 0.002, channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80, batch_size: int = 128, train_epoch: int = 80,
@ -106,7 +106,6 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
:param dataset: train dataset :param dataset: train dataset
:param iteration: identifier for the current training run :param iteration: identifier for the current training run
:param result_prefix: prefix for result images
:param weights_prefix: prefix for weights directory :param weights_prefix: prefix for weights directory
:param channels: number of channels in input image (default: 1) :param channels: number of channels in input image (default: 1)
:param zsize: size of the intermediary z (default: 32) :param zsize: size of the intermediary z (default: 32)
@ -116,31 +115,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
:param verbose: if True prints train progress info to console (default: True) :param verbose: if True prints train progress info to console (default: True)
""" """
# get models # non-preserved tensors
encoder = Encoder(zsize)
decoder = Decoder(channels)
z_discriminator = ZDiscriminator()
x_discriminator = XDiscriminator()
# define optimizers
learning_rate_var = k.variable(lr)
decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
# train
y_real = k.ones(batch_size) y_real = k.ones(batch_size)
y_fake = k.zeros(batch_size) y_fake = k.zeros(batch_size)
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
# z generator function
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
global_step_decoder = k.variable(0, dtype=tf.int64) # non-preserved python variables
global_step_enc_dec = k.variable(0, dtype=tf.int64)
global_step_xd = k.variable(0, dtype=tf.int64)
global_step_zd = k.variable(0, dtype=tf.int64)
encoder_lowest_loss = math.inf encoder_lowest_loss = math.inf
decoder_lowest_loss = math.inf decoder_lowest_loss = math.inf
enc_dec_lowest_loss = math.inf enc_dec_lowest_loss = math.inf
@ -149,148 +131,62 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
total_lowest_loss = math.inf total_lowest_loss = math.inf
grace_period = GRACE grace_period = GRACE
# checkpointed tensors and variables
checkpointables = {
'learning_rate_var': k.variable(lr),
}
checkpointables.update({
# get models
'encoder': Encoder(zsize),
'decoder': Decoder(channels),
'z_discriminator': ZDiscriminator(),
'x_discriminator': XDiscriminator(),
# define optimizers
'decoder_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
'enc_dec_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
'z_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999),
'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999),
# global step counter
'global_step_decoder': k.variable(0, dtype=tf.int64),
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
'global_step_xd': k.variable(0, dtype=tf.int64),
'global_step_zd': k.variable(0, dtype=tf.int64),
})
# checkpoint
checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/') checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir) latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(encoder=encoder, checkpoint = tf.train.Checkpoint(**checkpointables)
decoder=decoder, checkpoint.restore(latest_checkpoint)
z_discriminator=z_discriminator,
x_discriminator=x_discriminator,
decoder_optimizer=decoder_optimizer,
z_discriminator_optimizer=z_discriminator_optimizer,
x_discriminator_optimizer=x_discriminator_optimizer,
enc_dec_optimizer=enc_dec_optimizer,
global_step_decoder=global_step_decoder,
global_step_enc_dec=global_step_enc_dec,
global_step_xd=global_step_xd,
global_step_zd=global_step_zd,
learning_rate_var=learning_rate_var)
if latest_checkpoint is not None:
# if there is a checkpoint in the current training iteration, proceed from there
checkpoint.restore(latest_checkpoint)
for epoch in range(train_epoch): for epoch in range(train_epoch):
# define loss variables outputs = _train_one_epoch(epoch, dataset, targets_real=y_real,
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) targets_fake=y_fake, z_generator= z_generator,
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) verbose=verbose,
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) **checkpointables)
zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)
epoch_start_time = time.time()
# update learning rate
if (epoch + 1) % 30 == 0:
learning_rate_var.assign(learning_rate_var.value() / 4)
if verbose:
print("learning rate change!")
log_frequency = 10
batch_iteration = k.variable(0, dtype=tf.int64)
for x, _ in dataset:
# x discriminator
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
decoder=decoder,
optimizer=x_discriminator_optimizer,
inputs=x,
targets_real=y_real,
targets_fake=y_fake,
global_step=global_step_xd,
z_generator=z_generator)
xd_loss_avg(_xd_train_loss)
# --------
# decoder
_decoder_train_loss = train_decoder_step(decoder=decoder,
x_discriminator=x_discriminator,
optimizer=decoder_optimizer,
targets=y_real,
global_step=global_step_decoder,
z_generator=z_generator)
decoder_loss_avg(_decoder_train_loss)
# ---------
# z discriminator
_zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
encoder=encoder,
optimizer=z_discriminator_optimizer,
inputs=x,
targets_real=y_real,
targets_fake=y_fake,
global_step=global_step_zd,
z_generator=z_generator)
zd_loss_avg(_zd_train_loss)
# -----------
# encoder + decoder
encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
decoder=decoder,
z_discriminator=z_discriminator,
optimizer=enc_dec_optimizer,
inputs=x,
targets=y_real,
global_step=global_step_enc_dec)
enc_dec_loss_avg(reconstruction_loss)
encoder_loss_avg(encoder_loss)
if int(global_step_decoder % log_frequency) == 0:
# log the losses every log frequency batches
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)
if int(batch_iteration) == 0:
directory = 'results' + str(inlier_classes[0])
if not os.path.exists(directory):
os.makedirs(directory)
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
grid = prepare_image(comparison.cpu(), nrow=64)
summary_ops_v2.image(name='reconstruction_' + str(epoch),
tensor=k.expand_dims(grid, axis=0), max_images=1,
step=global_step_decoder)
from PIL import Image
filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png')
ndarr = grid.cpu().numpy()
im = Image.fromarray(ndarr)
im.save(filename)
batch_iteration.assign_add(1)
epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time
# final losses of epoch
decoder_loss = decoder_loss_avg.result(False)
encoder_loss = encoder_loss_avg.result(False)
enc_dec_loss = enc_dec_loss_avg.result(False)
xd_loss = xd_loss_avg.result(False)
zd_loss = zd_loss_avg.result(False)
if verbose: if verbose:
print(( print((
f"[{epoch + 1:d}/{train_epoch:d}] - " f"[{epoch + 1:d}/{train_epoch:d}] - "
f"train time: {per_epoch_time:.2f}, " f"train time: {outputs['per_epoch_time']:.2f}, "
f"Decoder loss: {decoder_loss:.3f}, " f"Decoder loss: {outputs['decoder_loss']:.3f}, "
f"X Discriminator loss: {xd_loss:.3f}, " f"X Discriminator loss: {outputs['xd_loss']:.3f}, "
f"Z Discriminator loss: {zd_loss:.3f}, " f"Z Discriminator loss: {outputs['zd_loss']:.3f}, "
f"Encoder + Decoder loss: {enc_dec_loss:.3f}, " f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
f"Encoder loss: {encoder_loss:.3f}" f"Encoder loss: {outputs['encoder_loss']:.3f}"
)) ))
# save sample image # save sample image summary
resultsample = decoder(sample).cpu() def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None:
directory = 'results' + str(inlier_classes[0]) resultsample = decoder(sample).cpu()
os.makedirs(directory, exist_ok=True) grid = prepare_image(resultsample)
grid = prepare_image(resultsample) summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), max_images=1, step=global_step_decoder)
max_images=1, step=global_step_decoder) _save_sample(**checkpointables)
from PIL import Image
filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png')
ndarr = grid.cpu().numpy()
im = Image.fromarray(ndarr)
im.save(filename)
# save weights at end of epoch # save weights at end of epoch
checkpoint.save(checkpoint_prefix) checkpoint.save(checkpoint_prefix)
@ -298,29 +194,30 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
# check for improvements in error reduction - otherwise early stopping # check for improvements in error reduction - otherwise early stopping
strike = False strike = False
total_strike = False total_strike = False
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
outputs['xd_loss'] + outputs['zd_loss']
if total_loss < total_lowest_loss: if total_loss < total_lowest_loss:
total_lowest_loss = total_loss total_lowest_loss = total_loss
elif total_loss > 6: elif total_loss > 6:
total_strike = True total_strike = True
if encoder_loss < encoder_lowest_loss: if outputs['encoder_loss'] < encoder_lowest_loss:
encoder_lowest_loss = encoder_loss encoder_lowest_loss = outputs['encoder_loss']
else: else:
strike = True strike = True
if decoder_loss < decoder_lowest_loss: if outputs['decoder_loss'] < decoder_lowest_loss:
decoder_lowest_loss = decoder_loss decoder_lowest_loss = outputs['decoder_loss']
else: else:
strike = True strike = True
if enc_dec_loss < enc_dec_lowest_loss: if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
enc_dec_lowest_loss = enc_dec_loss enc_dec_lowest_loss = outputs['enc_dec_loss']
else: else:
strike = True strike = True
if xd_loss < xd_lowest_loss: if outputs['xd_loss'] < xd_lowest_loss:
xd_lowest_loss = xd_loss xd_lowest_loss = outputs['xd_loss']
else: else:
strike = True strike = True
if zd_loss < zd_lowest_loss: if outputs['zd_loss'] < zd_lowest_loss:
zd_lowest_loss = zd_loss zd_lowest_loss = outputs['zd_loss']
else: else:
strike = True strike = True
@ -344,6 +241,114 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
checkpoint.save(checkpoint_prefix) checkpoint.save(checkpoint_prefix)
def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor,
                     verbose: bool,
                     targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
                     learning_rate_var: tf.Variable,
                     decoder: Decoder, encoder: Encoder, x_discriminator: XDiscriminator,
                     z_discriminator: ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
                     x_discriminator_optimizer: tf.train.Optimizer,
                     z_discriminator_optimizer: tf.train.Optimizer,
                     enc_dec_optimizer: tf.train.Optimizer,
                     global_step_xd: tf.Variable, global_step_zd: tf.Variable,
                     global_step_decoder: tf.Variable,
                     global_step_enc_dec: tf.Variable) -> Dict[str, float]:
    """
    Trains all models for one pass over the dataset.

    For every batch the following steps run in this order: x discriminator,
    decoder, z discriminator, encoder + decoder. The loss returned by each step
    is fed into a running mean that is created fresh for this epoch.

    Whenever ``global_step_decoder`` is a multiple of 10, the current running
    means are written as scalar summaries. On the first batch of the epoch an
    image summary named ``reconstruction_<epoch>`` is written that combines
    ``x[:64]`` with the first 64 entries of the decoded output returned by
    ``train_enc_dec_step``.

    If ``(epoch + 1) % 30 == 0``, the learning rate variable is divided by 4
    before the first batch is processed.

    :param epoch: index of the current epoch (the caller passes values from range(train_epoch))
    :param dataset: train dataset, iterated as (inputs, second element) pairs; the second element is ignored
    :param targets_real: targets passed to the train steps as "real" targets
    :param verbose: if True prints a message when the learning rate is changed
    :param targets_fake: targets passed to the discriminator steps as "fake" targets
    :param z_generator: callable without arguments, forwarded to the train steps
    :param learning_rate_var: learning rate variable; updated in place on decay epochs
    :param decoder: decoder model
    :param encoder: encoder model
    :param x_discriminator: x discriminator model
    :param z_discriminator: z discriminator model
    :param decoder_optimizer: optimizer for the decoder step
    :param x_discriminator_optimizer: optimizer for the x discriminator step
    :param z_discriminator_optimizer: optimizer for the z discriminator step
    :param enc_dec_optimizer: optimizer for the encoder + decoder step
    :param global_step_xd: global step counter of the x discriminator
    :param global_step_zd: global step counter of the z discriminator
    :param global_step_decoder: global step counter of the decoder
    :param global_step_enc_dec: global step counter of the encoder + decoder
    :return: dictionary with the epoch means under 'decoder_loss', 'encoder_loss',
        'enc_dec_loss', 'xd_loss' and 'zd_loss' (whatever ``result(False)`` of the
        metric returns) and the elapsed wall-clock seconds under 'per_epoch_time'
    """
    epoch_start_time = time.time()

    # define loss variables
    encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
    decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
    enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
    zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
    xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)

    # update learning rate
    if (epoch + 1) % 30 == 0:
        learning_rate_var.assign(learning_rate_var.value() / 4)
        if verbose:
            print("learning rate change!")

    log_frequency = 10

    # A plain Python index is sufficient to detect the first batch; no backend
    # variable has to be allocated per epoch or converted to int per batch.
    for batch_index, (x, _) in enumerate(dataset):
        # x discriminator
        _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
                                                   decoder=decoder,
                                                   optimizer=x_discriminator_optimizer,
                                                   inputs=x,
                                                   targets_real=targets_real,
                                                   targets_fake=targets_fake,
                                                   global_step=global_step_xd,
                                                   z_generator=z_generator)
        xd_loss_avg(_xd_train_loss)

        # --------
        # decoder
        _decoder_train_loss = train_decoder_step(decoder=decoder,
                                                 x_discriminator=x_discriminator,
                                                 optimizer=decoder_optimizer,
                                                 targets=targets_real,
                                                 global_step=global_step_decoder,
                                                 z_generator=z_generator)
        decoder_loss_avg(_decoder_train_loss)

        # ---------
        # z discriminator
        _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
                                                   encoder=encoder,
                                                   optimizer=z_discriminator_optimizer,
                                                   inputs=x,
                                                   targets_real=targets_real,
                                                   targets_fake=targets_fake,
                                                   global_step=global_step_zd,
                                                   z_generator=z_generator)
        zd_loss_avg(_zd_train_loss)

        # -----------
        # encoder + decoder
        encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
                                                                          decoder=decoder,
                                                                          z_discriminator=z_discriminator,
                                                                          optimizer=enc_dec_optimizer,
                                                                          inputs=x,
                                                                          targets=targets_real,
                                                                          global_step=global_step_enc_dec)
        enc_dec_loss_avg(reconstruction_loss)
        encoder_loss_avg(encoder_loss)

        if int(global_step_decoder % log_frequency) == 0:
            # log the losses every log frequency batches
            summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
            summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
            summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
            summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
            summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)

        if batch_index == 0:
            # only the first batch of an epoch is stored as reconstruction image
            comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
            grid = prepare_image(comparison.cpu(), nrow=64)
            summary_ops_v2.image(name='reconstruction_' + str(epoch),
                                 tensor=k.expand_dims(grid, axis=0), max_images=1,
                                 step=global_step_decoder)

    epoch_end_time = time.time()
    per_epoch_time = epoch_end_time - epoch_start_time

    # final losses of epoch
    outputs = {
        'decoder_loss': decoder_loss_avg.result(False),
        'encoder_loss': encoder_loss_avg.result(False),
        'enc_dec_loss': enc_dec_loss_avg.result(False),
        'xd_loss': xd_loss_avg.result(False),
        'zd_loss': zd_loss_avg.result(False),
        'per_epoch_time': per_epoch_time,
    }

    return outputs
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, targets_real: tf.Tensor, inputs: tf.Tensor, targets_real: tf.Tensor,
@ -525,5 +530,4 @@ if __name__ == "__main__":
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
train(dataset=train_dataset, iteration=iteration, train(dataset=train_dataset, iteration=iteration,
result_prefix='results' + str(inlier_classes[0]) + '/',
weights_prefix='weights/' + str(inlier_classes[0]) + '/') weights_prefix='weights/' + str(inlier_classes[0]) + '/')