Added latent space visualization to run situation as well

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-17 12:08:59 +02:00
parent 710c8359b3
commit 97b1b8963f

View File

@ -97,14 +97,15 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset,
dtype=tf.float32)
for x in dataset:
reconstruction_loss, x_decoded = _run_enc_dec_step_simple(encoder=encoder,
decoder=decoder,
inputs=x,
global_step=global_step)
reconstruction_loss, x_decoded, z = _run_enc_dec_step_simple(encoder=encoder,
decoder=decoder,
inputs=x,
global_step=global_step)
enc_dec_loss_avg(reconstruction_loss)
if int(global_step % train.LOG_FREQUENCY) == 0:
comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)]], axis=0)
comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2),
z[:int(batch_size / 2)]]], axis=0)
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size / 2))
summary_ops_v2.image(name='reconstruction',
tensor=K.expand_dims(grid, axis=0), max_images=1,
@ -125,7 +126,7 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset,
def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
inputs: tf.Tensor,
global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]:
global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Runs the encoder and decoder jointly for one step (one batch).
@ -136,7 +137,7 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
global_step: the global step variable
Returns:
tuple of reconstruction loss, reconstructed input
tuple of reconstruction loss, reconstructed input, latent space value
"""
z = encoder(inputs)
x_decoded = decoder(z)
@ -146,5 +147,9 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
if int(global_step % train.LOG_FREQUENCY) == 0:
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
step=global_step)
input_shape = tf.shape(inputs)
z_reshaped = tf.reshape(z, [-1, input_shape[1], input_shape[2], 1])
z_concatenated = K.concatenate((z_reshaped, z_reshaped, z_reshaped), axis=3)
return reconstruction_loss, x_decoded
return reconstruction_loss, x_decoded, z_concatenated