diff --git a/src/twomartens/masterthesis/aae/run.py b/src/twomartens/masterthesis/aae/run.py index 51eca36..aca79b3 100644 --- a/src/twomartens/masterthesis/aae/run.py +++ b/src/twomartens/masterthesis/aae/run.py @@ -97,15 +97,14 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset, dtype=tf.float32) for x in dataset: - reconstruction_loss, x_decoded, z = _run_enc_dec_step_simple(encoder=encoder, - decoder=decoder, - inputs=x, - global_step=global_step) + reconstruction_loss, x_decoded = _run_enc_dec_step_simple(encoder=encoder, + decoder=decoder, + inputs=x, + global_step=global_step) enc_dec_loss_avg(reconstruction_loss) if int(global_step % train.LOG_FREQUENCY) == 0: - comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)], - z[:int(batch_size / 2)]], axis=0) + comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)]], axis=0) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size / 2)) summary_ops_v2.image(name='reconstruction', tensor=K.expand_dims(grid, axis=0), max_images=1, @@ -126,7 +125,7 @@ def _run_one_epoch_simple(dataset: tf.data.Dataset, def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, inputs: tf.Tensor, - global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]: """ Runs the encoder and decoder jointly for one step (one batch). @@ -147,9 +146,5 @@ def _run_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, if int(global_step % train.LOG_FREQUENCY) == 0: summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss, step=global_step) - - input_shape = tf.shape(inputs) - z_reshaped = tf.reshape(z, [-1, input_shape[1], input_shape[2], 1]) - z_concatenated = K.concatenate((z_reshaped, z_reshaped, z_reshaped), axis=3) - return reconstruction_loss, x_decoded, z_concatenated + return reconstruction_loss, x_decoded diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 658fffc..6cc77d9 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -161,17 +161,16 @@ def _train_one_epoch_simple(epoch: int, print("learning rate change!") for x in dataset: - reconstruction_loss, x_decoded, z = _train_enc_dec_step_simple(encoder=encoder, - decoder=decoder, - optimizer=enc_dec_optimizer, - inputs=x, - global_step_enc_dec=global_step_enc_dec, - global_step=global_step) + reconstruction_loss, x_decoded = _train_enc_dec_step_simple(encoder=encoder, + decoder=decoder, + optimizer=enc_dec_optimizer, + inputs=x, + global_step_enc_dec=global_step_enc_dec, + global_step=global_step) enc_dec_loss_avg(reconstruction_loss) if int(global_step % LOG_FREQUENCY) == 0: - comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)], - z[:int(batch_size/2)]], axis=0) + comparison = K.concatenate([x[:int(batch_size / 2)], x_decoded[:int(batch_size / 2)]], axis=0) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2)) summary_ops_v2.image(name='reconstruction', tensor=K.expand_dims(grid, axis=0), max_images=1, @@ -194,8 +193,7 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, optimizer: tf.train.Optimizer, inputs: tf.Tensor, global_step: tf.Variable, - global_step_enc_dec: tf.Variable, - debug: bool) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor]: """ Trains the encoder and decoder jointly for one step (one batch). @@ -229,11 +227,8 @@ def _train_enc_dec_step_simple(encoder: model.Encoder, decoder: model.Decoder, optimizer.apply_gradients(zip(enc_dec_grads, encoder.trainable_variables + decoder.trainable_variables), global_step=global_step_enc_dec) - input_shape = tf.shape(inputs) - z_reshaped = tf.reshape(z, [-1, input_shape[1], input_shape[2], 1]) - z_expanded = K.concatenate((z_reshaped, z_reshaped, z_reshaped), axis=3) - return reconstruction_loss, x_decoded, z_expanded + return reconstruction_loss, x_decoded if __name__ == "__main__":