Added visualization for latent space

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-17 11:08:20 +02:00
parent e7e8e15d7f
commit eae908c617

View File

@ -161,16 +161,17 @@ def _train_one_epoch_simple(epoch: int,
print("learning rate change!") print("learning rate change!")
for x in dataset: for x in dataset:
reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder, reconstruction_loss, x_decoded, z = _train_enc_dec_step_simple(encoder=encoder,
decoder=decoder, decoder=decoder,
optimizer=enc_dec_optimizer, optimizer=enc_dec_optimizer,
inputs=x, inputs=x,
global_step_enc_dec=global_step_enc_dec, global_step_enc_dec=global_step_enc_dec,
global_step=global_step) global_step=global_step)
enc_dec_loss_avg(reconstruction_loss) enc_dec_loss_avg(reconstruction_loss)
if int(global_step % LOG_FREQUENCY) == 0 and verbose: if int(global_step % LOG_FREQUENCY) == 0 and verbose:
comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)]], axis=0) comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)],
z[:int(batch_size/2)]], axis=0)
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2)) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
summary_ops_v2.image(name='reconstruction', summary_ops_v2.image(name='reconstruction',
tensor=K.expand_dims(grid, axis=0), max_images=1, tensor=K.expand_dims(grid, axis=0), max_images=1,
@ -193,7 +194,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
optimizer: tf.train.Optimizer, optimizer: tf.train.Optimizer,
inputs: tf.Tensor, inputs: tf.Tensor,
global_step: tf.Variable, global_step: tf.Variable,
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]: global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
""" """
Trains the encoder and decoder jointly for one step (one batch). Trains the encoder and decoder jointly for one step (one batch).
@ -203,7 +204,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
:param inputs: inputs from data set :param inputs: inputs from data set
:param global_step: the global step variable :param global_step: the global step variable
:param global_step_enc_dec: global step variable for enc_dec :param global_step_enc_dec: global step variable for enc_dec
:return: tuple of reconstruction loss, reconstructed input :return: tuple of reconstruction loss, reconstructed input, z value
""" """
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
z = encoder(inputs) z = encoder(inputs)
@ -224,8 +225,11 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder,
optimizer.apply_gradients(zip(enc_dec_grads, optimizer.apply_gradients(zip(enc_dec_grads,
encoder.trainable_variables + decoder.trainable_variables), encoder.trainable_variables + decoder.trainable_variables),
global_step=global_step_enc_dec) global_step=global_step_enc_dec)
input_shape = tf.shape(inputs)
z_reshaped = tf.reshape(z, [-1, input_shape[1], input_shape[2], 1])
z_expanded = K.concatenate((z_reshaped, z_reshaped, z_reshaped), axis=3)
return reconstruction_loss, x_decoded return reconstruction_loss, x_decoded, z_expanded
if __name__ == "__main__": if __name__ == "__main__":