diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 94275db..4c8d0e9 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -211,7 +211,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, reconstruction_loss = tf.losses.log_loss(inputs, x_decoded) - enc_dec_grads = tape.gradient(_enc_dec_train_loss, + enc_dec_grads = tape.gradient(reconstruction_loss, encoder.trainable_variables + decoder.trainable_variables) if int(global_step % LOG_FREQUENCY) == 0: summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,