Introduced global steps for each optimizer

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 05:55:19 +01:00
parent bb76a51d76
commit 887435ddd7

View File

@ -101,7 +101,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
y_real_z = k.ones(batch_size)
y_fake_z = k.zeros(batch_size)
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
global_step = tf.train.get_or_create_global_step()
global_step_decoder = k.variable(0)
global_step_enc_dec = k.variable(0)
global_step_xd = k.variable(0)
global_step_zd = k.variable(0)
encoder_loss_history = []
decoder_loss_history = []
@ -151,7 +154,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
global_step=global_step)
global_step=global_step_xd)
xd_loss_avg(_xd_train_loss)
# --------
@ -166,7 +169,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
global_step=global_step)
global_step=global_step_decoder)
decoder_loss_avg(_decoder_train_loss)
# ---------
@ -186,7 +189,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
global_step=global_step)
global_step=global_step_zd)
zd_loss_avg(_zd_train_loss)
# -----------
@ -203,7 +206,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
encoder.trainable_variables + decoder.trainable_variables)
enc_dec_optimizer.apply_gradients(zip(enc_dec_grads,
encoder.trainable_variables + decoder.trainable_variables))
encoder.trainable_variables + decoder.trainable_variables),
global_step=global_step_enc_dec)
enc_dec_loss_avg(recovery_loss)
encoder_loss_avg(encoder_loss)