Moved training of epoch into separate function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -20,7 +20,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Sequence, Tuple
|
from typing import Callable, Dict, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -96,7 +96,7 @@ def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], tota
|
|||||||
return train_dataset, valid_dataset
|
return train_dataset, valid_dataset
|
||||||
|
|
||||||
|
|
||||||
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
def train(dataset: tf.data.Dataset, iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
@ -106,7 +106,6 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
|
|
||||||
:param dataset: train dataset
|
:param dataset: train dataset
|
||||||
:param iteration: identifier for the current training run
|
:param iteration: identifier for the current training run
|
||||||
:param result_prefix: prefix for result images
|
|
||||||
:param weights_prefix: prefix for weights directory
|
:param weights_prefix: prefix for weights directory
|
||||||
:param channels: number of channels in input image (default: 1)
|
:param channels: number of channels in input image (default: 1)
|
||||||
:param zsize: size of the intermediary z (default: 32)
|
:param zsize: size of the intermediary z (default: 32)
|
||||||
@ -116,31 +115,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
:param verbose: if True prints train progress info to console (default: True)
|
:param verbose: if True prints train progress info to console (default: True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# get models
|
# non-preserved tensors
|
||||||
encoder = Encoder(zsize)
|
|
||||||
decoder = Decoder(channels)
|
|
||||||
z_discriminator = ZDiscriminator()
|
|
||||||
x_discriminator = XDiscriminator()
|
|
||||||
|
|
||||||
# define optimizers
|
|
||||||
learning_rate_var = k.variable(lr)
|
|
||||||
decoder_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
|
||||||
enc_dec_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
|
||||||
z_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
|
||||||
x_discriminator_optimizer = AdamOptimizer(learning_rate=learning_rate_var, beta1=0.5, beta2=0.999)
|
|
||||||
|
|
||||||
# train
|
|
||||||
y_real = k.ones(batch_size)
|
y_real = k.ones(batch_size)
|
||||||
y_fake = k.zeros(batch_size)
|
y_fake = k.zeros(batch_size)
|
||||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
||||||
|
# z generator function
|
||||||
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||||
|
|
||||||
global_step_decoder = k.variable(0, dtype=tf.int64)
|
# non-preserved python variables
|
||||||
global_step_enc_dec = k.variable(0, dtype=tf.int64)
|
|
||||||
global_step_xd = k.variable(0, dtype=tf.int64)
|
|
||||||
global_step_zd = k.variable(0, dtype=tf.int64)
|
|
||||||
|
|
||||||
encoder_lowest_loss = math.inf
|
encoder_lowest_loss = math.inf
|
||||||
decoder_lowest_loss = math.inf
|
decoder_lowest_loss = math.inf
|
||||||
enc_dec_lowest_loss = math.inf
|
enc_dec_lowest_loss = math.inf
|
||||||
@ -149,148 +131,62 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
total_lowest_loss = math.inf
|
total_lowest_loss = math.inf
|
||||||
grace_period = GRACE
|
grace_period = GRACE
|
||||||
|
|
||||||
|
# checkpointed tensors and variables
|
||||||
|
checkpointables = {
|
||||||
|
'learning_rate_var': k.variable(lr),
|
||||||
|
}
|
||||||
|
checkpointables.update({
|
||||||
|
# get models
|
||||||
|
'encoder': Encoder(zsize),
|
||||||
|
'decoder': Decoder(channels),
|
||||||
|
'z_discriminator': ZDiscriminator(),
|
||||||
|
'x_discriminator': XDiscriminator(),
|
||||||
|
# define optimizers
|
||||||
|
'decoder_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
||||||
|
'enc_dec_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
||||||
|
'z_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
|
beta1=0.5, beta2=0.999),
|
||||||
|
'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
|
beta1=0.5, beta2=0.999),
|
||||||
|
# global step counter
|
||||||
|
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
||||||
|
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
||||||
|
'global_step_xd': k.variable(0, dtype=tf.int64),
|
||||||
|
'global_step_zd': k.variable(0, dtype=tf.int64),
|
||||||
|
})
|
||||||
|
|
||||||
|
# checkpoint
|
||||||
checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
|
checkpoint_dir = os.path.join(weights_prefix, str(iteration) + '/')
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
|
||||||
checkpoint = tf.train.Checkpoint(encoder=encoder,
|
checkpoint = tf.train.Checkpoint(**checkpointables)
|
||||||
decoder=decoder,
|
checkpoint.restore(latest_checkpoint)
|
||||||
z_discriminator=z_discriminator,
|
|
||||||
x_discriminator=x_discriminator,
|
|
||||||
decoder_optimizer=decoder_optimizer,
|
|
||||||
z_discriminator_optimizer=z_discriminator_optimizer,
|
|
||||||
x_discriminator_optimizer=x_discriminator_optimizer,
|
|
||||||
enc_dec_optimizer=enc_dec_optimizer,
|
|
||||||
global_step_decoder=global_step_decoder,
|
|
||||||
global_step_enc_dec=global_step_enc_dec,
|
|
||||||
global_step_xd=global_step_xd,
|
|
||||||
global_step_zd=global_step_zd,
|
|
||||||
learning_rate_var=learning_rate_var)
|
|
||||||
if latest_checkpoint is not None:
|
|
||||||
# if there is a checkpoint in the current training iteration, proceed from there
|
|
||||||
checkpoint.restore(latest_checkpoint)
|
|
||||||
|
|
||||||
for epoch in range(train_epoch):
|
for epoch in range(train_epoch):
|
||||||
# define loss variables
|
outputs = _train_one_epoch(epoch, dataset, targets_real=y_real,
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
targets_fake=y_fake, z_generator= z_generator,
|
||||||
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
verbose=verbose,
|
||||||
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
**checkpointables)
|
||||||
zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
|
|
||||||
xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)
|
|
||||||
|
|
||||||
epoch_start_time = time.time()
|
|
||||||
|
|
||||||
# update learning rate
|
|
||||||
if (epoch + 1) % 30 == 0:
|
|
||||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
|
||||||
if verbose:
|
|
||||||
print("learning rate change!")
|
|
||||||
|
|
||||||
log_frequency = 10
|
|
||||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
|
||||||
for x, _ in dataset:
|
|
||||||
# x discriminator
|
|
||||||
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
|
||||||
decoder=decoder,
|
|
||||||
optimizer=x_discriminator_optimizer,
|
|
||||||
inputs=x,
|
|
||||||
targets_real=y_real,
|
|
||||||
targets_fake=y_fake,
|
|
||||||
global_step=global_step_xd,
|
|
||||||
z_generator=z_generator)
|
|
||||||
xd_loss_avg(_xd_train_loss)
|
|
||||||
|
|
||||||
# --------
|
|
||||||
# decoder
|
|
||||||
_decoder_train_loss = train_decoder_step(decoder=decoder,
|
|
||||||
x_discriminator=x_discriminator,
|
|
||||||
optimizer=decoder_optimizer,
|
|
||||||
targets=y_real,
|
|
||||||
global_step=global_step_decoder,
|
|
||||||
z_generator=z_generator)
|
|
||||||
decoder_loss_avg(_decoder_train_loss)
|
|
||||||
|
|
||||||
# ---------
|
|
||||||
# z discriminator
|
|
||||||
_zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
|
|
||||||
encoder=encoder,
|
|
||||||
optimizer=z_discriminator_optimizer,
|
|
||||||
inputs=x,
|
|
||||||
targets_real=y_real,
|
|
||||||
targets_fake=y_fake,
|
|
||||||
global_step=global_step_zd,
|
|
||||||
z_generator=z_generator)
|
|
||||||
zd_loss_avg(_zd_train_loss)
|
|
||||||
|
|
||||||
# -----------
|
|
||||||
# encoder + decoder
|
|
||||||
encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
z_discriminator=z_discriminator,
|
|
||||||
optimizer=enc_dec_optimizer,
|
|
||||||
inputs=x,
|
|
||||||
targets=y_real,
|
|
||||||
global_step=global_step_enc_dec)
|
|
||||||
enc_dec_loss_avg(reconstruction_loss)
|
|
||||||
encoder_loss_avg(encoder_loss)
|
|
||||||
|
|
||||||
if int(global_step_decoder % log_frequency) == 0:
|
|
||||||
# log the losses every log frequency batches
|
|
||||||
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
|
|
||||||
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
|
|
||||||
summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
|
|
||||||
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
|
|
||||||
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)
|
|
||||||
|
|
||||||
if int(batch_iteration) == 0:
|
|
||||||
directory = 'results' + str(inlier_classes[0])
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
os.makedirs(directory)
|
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
|
||||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
|
||||||
summary_ops_v2.image(name='reconstruction_' + str(epoch),
|
|
||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
|
||||||
step=global_step_decoder)
|
|
||||||
from PIL import Image
|
|
||||||
filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png')
|
|
||||||
ndarr = grid.cpu().numpy()
|
|
||||||
im = Image.fromarray(ndarr)
|
|
||||||
im.save(filename)
|
|
||||||
|
|
||||||
batch_iteration.assign_add(1)
|
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
|
||||||
|
|
||||||
# final losses of epoch
|
|
||||||
decoder_loss = decoder_loss_avg.result(False)
|
|
||||||
encoder_loss = encoder_loss_avg.result(False)
|
|
||||||
enc_dec_loss = enc_dec_loss_avg.result(False)
|
|
||||||
xd_loss = xd_loss_avg.result(False)
|
|
||||||
zd_loss = zd_loss_avg.result(False)
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print((
|
print((
|
||||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||||
f"train time: {per_epoch_time:.2f}, "
|
f"train time: {outputs['per_epoch_time']:.2f}, "
|
||||||
f"Decoder loss: {decoder_loss:.3f}, "
|
f"Decoder loss: {outputs['decoder_loss']:.3f}, "
|
||||||
f"X Discriminator loss: {xd_loss:.3f}, "
|
f"X Discriminator loss: {outputs['xd_loss']:.3f}, "
|
||||||
f"Z Discriminator loss: {zd_loss:.3f}, "
|
f"Z Discriminator loss: {outputs['zd_loss']:.3f}, "
|
||||||
f"Encoder + Decoder loss: {enc_dec_loss:.3f}, "
|
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
|
||||||
f"Encoder loss: {encoder_loss:.3f}"
|
f"Encoder loss: {outputs['encoder_loss']:.3f}"
|
||||||
))
|
))
|
||||||
|
|
||||||
# save sample image
|
# save sample image summary
|
||||||
resultsample = decoder(sample).cpu()
|
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None:
|
||||||
directory = 'results' + str(inlier_classes[0])
|
resultsample = decoder(sample).cpu()
|
||||||
os.makedirs(directory, exist_ok=True)
|
grid = prepare_image(resultsample)
|
||||||
grid = prepare_image(resultsample)
|
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
||||||
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
max_images=1, step=global_step_decoder)
|
||||||
max_images=1, step=global_step_decoder)
|
_save_sample(**checkpointables)
|
||||||
from PIL import Image
|
|
||||||
filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png')
|
|
||||||
ndarr = grid.cpu().numpy()
|
|
||||||
im = Image.fromarray(ndarr)
|
|
||||||
im.save(filename)
|
|
||||||
|
|
||||||
# save weights at end of epoch
|
# save weights at end of epoch
|
||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
@ -298,29 +194,30 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
# check for improvements in error reduction - otherwise early stopping
|
# check for improvements in error reduction - otherwise early stopping
|
||||||
strike = False
|
strike = False
|
||||||
total_strike = False
|
total_strike = False
|
||||||
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
|
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||||
|
outputs['xd_loss'] + outputs['zd_loss']
|
||||||
if total_loss < total_lowest_loss:
|
if total_loss < total_lowest_loss:
|
||||||
total_lowest_loss = total_loss
|
total_lowest_loss = total_loss
|
||||||
elif total_loss > 6:
|
elif total_loss > 6:
|
||||||
total_strike = True
|
total_strike = True
|
||||||
if encoder_loss < encoder_lowest_loss:
|
if outputs['encoder_loss'] < encoder_lowest_loss:
|
||||||
encoder_lowest_loss = encoder_loss
|
encoder_lowest_loss = outputs['encoder_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
if decoder_loss < decoder_lowest_loss:
|
if outputs['decoder_loss'] < decoder_lowest_loss:
|
||||||
decoder_lowest_loss = decoder_loss
|
decoder_lowest_loss = outputs['decoder_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
if enc_dec_loss < enc_dec_lowest_loss:
|
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
||||||
enc_dec_lowest_loss = enc_dec_loss
|
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
if xd_loss < xd_lowest_loss:
|
if outputs['xd_loss'] < xd_lowest_loss:
|
||||||
xd_lowest_loss = xd_loss
|
xd_lowest_loss = outputs['xd_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
if zd_loss < zd_lowest_loss:
|
if outputs['zd_loss'] < zd_lowest_loss:
|
||||||
zd_lowest_loss = zd_loss
|
zd_lowest_loss = outputs['zd_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
|
|
||||||
@ -344,6 +241,114 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor,
                     verbose: bool,
                     targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
                     learning_rate_var: tf.Variable,
                     decoder: Decoder, encoder: Encoder, x_discriminator: XDiscriminator,
                     z_discriminator: ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
                     x_discriminator_optimizer: tf.train.Optimizer,
                     z_discriminator_optimizer: tf.train.Optimizer,
                     enc_dec_optimizer: tf.train.Optimizer,
                     global_step_xd: tf.Variable, global_step_zd: tf.Variable,
                     global_step_decoder: tf.Variable,
                     global_step_enc_dec: tf.Variable) -> Dict[str, float]:
    """
    Trains all networks for one pass over the dataset.

    For every batch the x discriminator, the decoder, the z discriminator and
    the encoder + decoder are trained in that order. Running loss averages are
    written as scalar summaries whenever the decoder global step is a multiple
    of the log frequency; a reconstruction image summary is written for the
    first batch of the epoch.

    :param epoch: zero-based index of the current epoch
    :param dataset: train dataset; iterated as (input, label) pairs, labels are ignored
    :param targets_real: targets passed to the train steps for real inputs
    :param verbose: if True prints a message when the learning rate is changed
    :param targets_fake: targets passed to the discriminator train steps for fake inputs
    :param z_generator: callable forwarded to the train steps to obtain a z variable
    :param learning_rate_var: learning rate variable; divided by 4 in place every 30th epoch
    :param decoder: decoder model
    :param encoder: encoder model
    :param x_discriminator: x discriminator model
    :param z_discriminator: z discriminator model
    :param decoder_optimizer: optimizer for the decoder step
    :param x_discriminator_optimizer: optimizer for the x discriminator step
    :param z_discriminator_optimizer: optimizer for the z discriminator step
    :param enc_dec_optimizer: optimizer for the encoder + decoder step
    :param global_step_xd: global step counter passed to the x discriminator step
    :param global_step_zd: global step counter passed to the z discriminator step
    :param global_step_decoder: global step counter passed to the decoder step
    :param global_step_enc_dec: global step counter passed to the encoder + decoder step
    :return: dictionary with the keys 'decoder_loss', 'encoder_loss', 'enc_dec_loss',
        'xd_loss', 'zd_loss' (epoch averages as returned by the metric objects,
        presumably scalar tensors - TODO confirm) and 'per_epoch_time' (seconds as float)
    """
    epoch_start_time = time.time()

    # running averages of the individual losses over this epoch
    encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
    decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
    enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
    zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
    xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)

    # update learning rate: the variable is shared, so the change is visible to
    # every optimizer that was constructed with it
    if (epoch + 1) % 30 == 0:
        learning_rate_var.assign(learning_rate_var.value() / 4)
        if verbose:
            print("learning rate change!")

    log_frequency = 10
    # a plain Python counter is enough to detect the first batch; no backend
    # variable has to be allocated and read back on every iteration
    for batch_index, (x, _) in enumerate(dataset):
        # x discriminator
        _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
                                                   decoder=decoder,
                                                   optimizer=x_discriminator_optimizer,
                                                   inputs=x,
                                                   targets_real=targets_real,
                                                   targets_fake=targets_fake,
                                                   global_step=global_step_xd,
                                                   z_generator=z_generator)
        xd_loss_avg(_xd_train_loss)

        # --------
        # decoder
        _decoder_train_loss = train_decoder_step(decoder=decoder,
                                                 x_discriminator=x_discriminator,
                                                 optimizer=decoder_optimizer,
                                                 targets=targets_real,
                                                 global_step=global_step_decoder,
                                                 z_generator=z_generator)
        decoder_loss_avg(_decoder_train_loss)

        # ---------
        # z discriminator
        _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
                                                   encoder=encoder,
                                                   optimizer=z_discriminator_optimizer,
                                                   inputs=x,
                                                   targets_real=targets_real,
                                                   targets_fake=targets_fake,
                                                   global_step=global_step_zd,
                                                   z_generator=z_generator)
        zd_loss_avg(_zd_train_loss)

        # -----------
        # encoder + decoder
        encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
                                                                          decoder=decoder,
                                                                          z_discriminator=z_discriminator,
                                                                          optimizer=enc_dec_optimizer,
                                                                          inputs=x,
                                                                          targets=targets_real,
                                                                          global_step=global_step_enc_dec)
        enc_dec_loss_avg(reconstruction_loss)
        encoder_loss_avg(encoder_loss)

        if int(global_step_decoder % log_frequency) == 0:
            # log the losses every log frequency batches
            summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
            summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
            summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
            summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
            summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)

        if batch_index == 0:
            # first batch of the epoch: store inputs next to their reconstructions
            comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
            grid = prepare_image(comparison.cpu(), nrow=64)
            summary_ops_v2.image(name='reconstruction_' + str(epoch),
                                 tensor=k.expand_dims(grid, axis=0), max_images=1,
                                 step=global_step_decoder)

    epoch_end_time = time.time()
    per_epoch_time = epoch_end_time - epoch_start_time

    # final losses of epoch
    outputs = {
        'decoder_loss': decoder_loss_avg.result(False),
        'encoder_loss': encoder_loss_avg.result(False),
        'enc_dec_loss': enc_dec_loss_avg.result(False),
        'xd_loss': xd_loss_avg.result(False),
        'zd_loss': zd_loss_avg.result(False),
        'per_epoch_time': per_epoch_time,
    }

    return outputs
|
||||||
|
|
||||||
|
|
||||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||||
@ -525,5 +530,4 @@ if __name__ == "__main__":
|
|||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||||
train(dataset=train_dataset, iteration=iteration,
|
train(dataset=train_dataset, iteration=iteration,
|
||||||
result_prefix='results' + str(inlier_classes[0]) + '/',
|
|
||||||
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||||
|
|||||||
Reference in New Issue
Block a user