diff --git a/src/twomartens/masterthesis/aae/run.py b/src/twomartens/masterthesis/aae/run.py index e9e2f84..0fc8f54 100644 --- a/src/twomartens/masterthesis/aae/run.py +++ b/src/twomartens/masterthesis/aae/run.py @@ -97,14 +97,15 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset, dtype=tf.float32) for x in dataset: - reconstruction_loss, x_decoded = _run_enc_dec_step_simple(encoder=encoder, - decoder=decoder, - inputs=x, - global_step=global_step) + reconstruction_loss, x_decoded, z = _run_enc_dec_step_simple(encoder=encoder, + decoder=decoder, + inputs=x, + global_step=global_step) enc_dec_loss_avg(reconstruction_loss) if int(global_step % train.LOG_FREQUENCY) == 0: - comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)]], axis=0) + comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2), + z[:int(batch_size / 2)]]], axis=0) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size / 2)) summary_ops_v2.image(name='reconstruction', tensor=K.expand_dims(grid, axis=0), max_images=1, @@ -125,7 +126,7 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset, def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, inputs: tf.Tensor, - global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]: + global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: """ Runs the encoder and decoder jointly for one step (one batch). @@ -136,7 +137,7 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, global_step: the global step variable Returns: - tuple of reconstruction loss, reconstructed input + tuple of reconstruction loss, reconstructed input, latent space value """ z = encoder(inputs) x_decoded = decoder(z) @@ -146,5 +147,9 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, if int(global_step % train.LOG_FREQUENCY) == 0: summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss, step=global_step) + + input_shape = tf.shape(inputs) + z_reshaped = tf.reshape(z, [-1, input_shape[1], input_shape[2], 1]) + z_concatenated = K.concatenate((z_reshaped, z_reshaped, z_reshaped), axis=3) - return reconstruction_loss, x_decoded + return reconstruction_loss, x_decoded, z_concatenated