@ -194,7 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
beta1=0.5, beta2=0.999),
|
||||
# global step counter
|
||||
'epoch_var': k.variable(-1, dtype=tf.int64),
|
||||
'global_step': k.variable(0, dtype=tf.int64),
|
||||
'global_step': tf.train.get_or_create_global_step(),
|
||||
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
||||
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
||||
'global_step_xd': k.variable(0, dtype=tf.int64),
|
||||
@ -314,8 +314,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
global_step_enc_dec: tf.Variable,
|
||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||
|
||||
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
|
||||
global_step=global_step):
|
||||
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY):
|
||||
epoch_var.assign(epoch)
|
||||
epoch_start_time = time.time()
|
||||
# define loss variables
|
||||
@ -331,7 +330,6 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
if verbose:
|
||||
print("learning rate change!")
|
||||
|
||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
||||
for x, _ in dataset:
|
||||
# x discriminator
|
||||
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||
@ -378,14 +376,11 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
encoder_loss_avg(encoder_loss)
|
||||
|
||||
if int(batch_iteration) == 0:
|
||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
||||
summary_ops_v2.image(name='reconstruction',
|
||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||
step=global_step_decoder)
|
||||
|
||||
batch_iteration.assign_add(1)
|
||||
step=global_step)
|
||||
global_step.assign_add(1)
|
||||
|
||||
epoch_end_time = time.time()
|
||||
|
||||
Reference in New Issue
Block a user