Introduced global steps for each optimizer
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -101,7 +101,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
y_real_z = k.ones(batch_size)
|
||||
y_fake_z = k.zeros(batch_size)
|
||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
global_step_decoder = k.variable(0)
|
||||
global_step_enc_dec = k.variable(0)
|
||||
global_step_xd = k.variable(0)
|
||||
global_step_zd = k.variable(0)
|
||||
|
||||
encoder_loss_history = []
|
||||
decoder_loss_history = []
|
||||
@ -151,7 +154,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
||||
x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
||||
global_step=global_step)
|
||||
global_step=global_step_xd)
|
||||
xd_loss_avg(_xd_train_loss)
|
||||
|
||||
# --------
|
||||
@ -166,7 +169,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
||||
decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
|
||||
global_step=global_step)
|
||||
global_step=global_step_decoder)
|
||||
decoder_loss_avg(_decoder_train_loss)
|
||||
|
||||
# ---------
|
||||
@ -186,7 +189,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
||||
z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
||||
global_step=global_step)
|
||||
global_step=global_step_zd)
|
||||
zd_loss_avg(_zd_train_loss)
|
||||
|
||||
# -----------
|
||||
@ -203,7 +206,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
||||
encoder.trainable_variables + decoder.trainable_variables)
|
||||
enc_dec_optimizer.apply_gradients(zip(enc_dec_grads,
|
||||
encoder.trainable_variables + decoder.trainable_variables))
|
||||
encoder.trainable_variables + decoder.trainable_variables),
|
||||
global_step=global_step_enc_dec)
|
||||
enc_dec_loss_avg(recovery_loss)
|
||||
encoder_loss_avg(encoder_loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user