@ -194,7 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
# global step counter
|
# global step counter
|
||||||
'epoch_var': k.variable(-1, dtype=tf.int64),
|
'epoch_var': k.variable(-1, dtype=tf.int64),
|
||||||
'global_step': k.variable(0, dtype=tf.int64),
|
'global_step': tf.train.get_or_create_global_step(),
|
||||||
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
||||||
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
||||||
'global_step_xd': k.variable(0, dtype=tf.int64),
|
'global_step_xd': k.variable(0, dtype=tf.int64),
|
||||||
@ -314,8 +314,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
global_step_enc_dec: tf.Variable,
|
global_step_enc_dec: tf.Variable,
|
||||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||||
|
|
||||||
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
|
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY):
|
||||||
global_step=global_step):
|
|
||||||
epoch_var.assign(epoch)
|
epoch_var.assign(epoch)
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
# define loss variables
|
# define loss variables
|
||||||
@ -331,7 +330,6 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
if verbose:
|
if verbose:
|
||||||
print("learning rate change!")
|
print("learning rate change!")
|
||||||
|
|
||||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
|
||||||
for x, _ in dataset:
|
for x, _ in dataset:
|
||||||
# x discriminator
|
# x discriminator
|
||||||
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||||
@ -378,14 +376,11 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
enc_dec_loss_avg(reconstruction_loss)
|
enc_dec_loss_avg(reconstruction_loss)
|
||||||
encoder_loss_avg(encoder_loss)
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
if int(batch_iteration) == 0:
|
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
grid = prepare_image(comparison.cpu(), nrow=64)
|
||||||
summary_ops_v2.image(name='reconstruction',
|
summary_ops_v2.image(name='reconstruction',
|
||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||||
step=global_step_decoder)
|
step=global_step)
|
||||||
|
|
||||||
batch_iteration.assign_add(1)
|
|
||||||
global_step.assign_add(1)
|
global_step.assign_add(1)
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
|
|||||||
Reference in New Issue
Block a user