diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 6f3fe07..8376e7f 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -101,7 +101,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i y_real_z = k.ones(batch_size) y_fake_z = k.zeros(batch_size) sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) - global_step = tf.train.get_or_create_global_step() + global_step_decoder = k.variable(0) + global_step_enc_dec = k.variable(0) + global_step_xd = k.variable(0) + global_step_zd = k.variable(0) encoder_loss_history = [] decoder_loss_history = [] @@ -151,7 +154,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables) x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables), - global_step=global_step) + global_step=global_step_xd) xd_loss_avg(_xd_train_loss) # -------- @@ -166,7 +169,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables) decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables), - global_step=global_step) + global_step=global_step_decoder) decoder_loss_avg(_decoder_train_loss) # --------- @@ -186,7 +189,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables) z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables), - global_step=global_step) + global_step=global_step_zd) zd_loss_avg(_zd_train_loss) # ----------- @@ -203,7 +206,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i enc_dec_grads = tape.gradient(_enc_dec_train_loss, encoder.trainable_variables + decoder.trainable_variables) enc_dec_optimizer.apply_gradients(zip(enc_dec_grads, - encoder.trainable_variables + decoder.trainable_variables)) + encoder.trainable_variables + decoder.trainable_variables), + global_step=global_step_enc_dec) enc_dec_loss_avg(recovery_loss) encoder_loss_avg(encoder_loss)