diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 8376e7f..ea2288c 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -15,11 +15,12 @@ # limitations under the License. """aae.train.py: contains training functionality""" +import functools import os import pickle import random import time -from typing import Sequence, Tuple +from typing import Callable, Sequence, Tuple import numpy as np import tensorflow as tf @@ -101,6 +102,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i y_real_z = k.ones(batch_size) y_fake_z = k.zeros(batch_size) sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1) + + z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize) + global_step_decoder = k.variable(0) global_step_enc_dec = k.variable(0) global_step_xd = k.variable(0) @@ -140,75 +144,48 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i for it in range(nr_batches): x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size)) # x discriminator - with tf.GradientTape() as tape: - xd_result = tf.squeeze(x_discriminator(x)) - xd_real_loss = binary_crossentropy(y_real, xd_result) - z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize)) - z = k.variable(z) - - x_fake = decoder(z) - xd_result = tf.squeeze(x_discriminator(x_fake)) - xd_fake_loss = binary_crossentropy(y_fake, xd_result) - - _xd_train_loss = xd_real_loss + xd_fake_loss - - xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables) - x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables), - global_step=global_step_xd) + _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator, + decoder=decoder, + optimizer=x_discriminator_optimizer, + inputs=x, + targets_real=y_real, + targets_fake=y_fake, + 
                                                       global_step=global_step_xd,
+                                                       z_generator=z_generator)
             xd_loss_avg(_xd_train_loss)
 
             # --------
             # decoder
-            with tf.GradientTape() as tape:
-                z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
-                z = k.variable(z)
-
-                x_fake = decoder(z)
-                xd_result = tf.squeeze(x_discriminator(x_fake))
-                _decoder_train_loss = binary_crossentropy(y_real, xd_result)
-
-            decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
-            decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
-                                              global_step=global_step_decoder)
+            _decoder_train_loss = train_decoder_step(decoder=decoder,
+                                                     x_discriminator=x_discriminator,
+                                                     optimizer=decoder_optimizer,
+                                                     targets=y_real,
+                                                     global_step=global_step_decoder,
+                                                     z_generator=z_generator)
             decoder_loss_avg(_decoder_train_loss)
 
             # ---------
             # z discriminator
-            with tf.GradientTape() as tape:
-                z = k.reshape(k.random_normal((batch_size, zsize)), (-1, zsize))
-                z = k.variable(z)
-
-                zd_result = tf.squeeze(z_discriminator(z))
-                zd_real_loss = binary_crossentropy(y_real_z, zd_result)
-
-                z = tf.squeeze(encoder(x))
-                zd_result = tf.squeeze(z_discriminator(z))
-                zd_fake_loss = binary_crossentropy(y_fake_z, zd_result)
-
-                _zd_train_loss = zd_real_loss + zd_fake_loss
-
-            zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
-            z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
-                                                      global_step=global_step_zd)
+            _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
+                                                       encoder=encoder,
+                                                       optimizer=z_discriminator_optimizer,
+                                                       inputs=x,
+                                                       targets_real=y_real_z,
+                                                       targets_fake=y_fake_z,
+                                                       global_step=global_step_zd,
+                                                       z_generator=z_generator)
             zd_loss_avg(_zd_train_loss)
 
             # -----------
             # encoder + decoder
-            with tf.GradientTape() as tape:
-                z = encoder(x)
-                x_decoded = decoder(z)
-
-                zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
-                encoder_loss = binary_crossentropy(y_real_z, zd_result) * 2.0
-                recovery_loss =
binary_crossentropy(x, x_decoded)
-                _enc_dec_train_loss = encoder_loss + recovery_loss
-
-            enc_dec_grads = tape.gradient(_enc_dec_train_loss,
-                                          encoder.trainable_variables + decoder.trainable_variables)
-            enc_dec_optimizer.apply_gradients(zip(enc_dec_grads,
-                                                  encoder.trainable_variables + decoder.trainable_variables),
-                                              global_step=global_step_enc_dec)
-            enc_dec_loss_avg(recovery_loss)
+            encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
+                                                                              decoder=decoder,
+                                                                              z_discriminator=z_discriminator,
+                                                                              optimizer=enc_dec_optimizer,
+                                                                              inputs=x,
+                                                                              targets=y_real_z,
+                                                                              global_step=global_step_enc_dec)
+            enc_dec_loss_avg(reconstruction_loss)
             encoder_loss_avg(encoder_loss)
 
             if it == 0:
@@ -261,6 +238,153 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         pickle.dump(zd_loss_history, file)
 
 
+def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
+                              optimizer: tf.train.Optimizer,
+                              inputs: tf.Tensor, targets_real: tf.Tensor,
+                              targets_fake: tf.Tensor, global_step: tf.Variable,
+                              z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the x discriminator model for one step (one batch).
+
+    :param x_discriminator: instance of x discriminator model
+    :param decoder: instance of decoder model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets_real: target tensor for real loss calculation
+    :param targets_fake: target tensor for fake loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        xd_result = tf.squeeze(x_discriminator(inputs))
+        xd_real_loss = binary_crossentropy(targets_real, xd_result)
+
+        z = z_generator()
+        x_fake = decoder(z)
+        xd_result = tf.squeeze(x_discriminator(x_fake))
+        xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
+
+        _xd_train_loss = xd_real_loss + xd_fake_loss
+
+    xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
+    optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
+                              global_step=global_step)
+
+    return _xd_train_loss
+
+
+def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
+                       optimizer: tf.train.Optimizer,
+                       targets: tf.Tensor, global_step: tf.Variable,
+                       z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the decoder model for one step (one batch).
+
+    :param decoder: instance of decoder model
+    :param x_discriminator: instance of the x discriminator model
+    :param optimizer: instance of chosen optimizer
+    :param targets: target tensor for loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        z = z_generator()
+
+        x_fake = decoder(z)
+        xd_result = tf.squeeze(x_discriminator(x_fake))
+        _decoder_train_loss = binary_crossentropy(targets, xd_result)
+
+    grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
+    optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
+                              global_step=global_step)
+
+    return _decoder_train_loss
+
+
+def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
+                              optimizer: tf.train.Optimizer,
+                              inputs: tf.Tensor, targets_real: tf.Tensor,
+                              targets_fake: tf.Tensor, global_step: tf.Variable,
+                              z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the z discriminator one step (one batch).
+
+    :param z_discriminator: instance of z discriminator model
+    :param encoder: instance of encoder model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets_real: target tensor for real loss calculation
+    :param targets_fake: target tensor for fake loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        # NOTE(review): z_generator yields shape (batch, 1, 1, zsize), but the
+        # pre-refactoring code here sampled a (batch, zsize) tensor for the
+        # z discriminator -- confirm ZDiscriminator treats both identically.
+        z = z_generator()
+
+        zd_result = tf.squeeze(z_discriminator(z))
+        zd_real_loss = binary_crossentropy(targets_real, zd_result)
+
+        z = tf.squeeze(encoder(inputs))
+        zd_result = tf.squeeze(z_discriminator(z))
+        zd_fake_loss = binary_crossentropy(targets_fake, zd_result)
+
+        _zd_train_loss = zd_real_loss + zd_fake_loss
+
+    zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
+    optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
+                              global_step=global_step)
+
+    return _zd_train_loss
+
+
+def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
+                       optimizer: tf.train.Optimizer, inputs: tf.Tensor,
+                       targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    """
+    Trains the encoder and decoder jointly for one step (one batch).
+
+    :param encoder: instance of encoder model
+    :param decoder: instance of decoder model
+    :param z_discriminator: instance of z discriminator model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets: target tensor for loss calculation
+    :param global_step: the global step variable
+    :return: tuple of encoder loss, reconstruction loss, reconstructed input
+    """
+    with tf.GradientTape() as tape:
+        z = encoder(inputs)
+        x_decoded = decoder(z)
+
+        zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
+        encoder_loss = binary_crossentropy(targets, zd_result) * 2.0
+        reconstruction_loss = binary_crossentropy(inputs, x_decoded)
+        _enc_dec_train_loss = encoder_loss + reconstruction_loss
+
+    enc_dec_grads = tape.gradient(_enc_dec_train_loss,
+                                  encoder.trainable_variables + decoder.trainable_variables)
+    optimizer.apply_gradients(zip(enc_dec_grads,
+                                  encoder.trainable_variables + decoder.trainable_variables),
+                              global_step=global_step)
+
+    return encoder_loss, reconstruction_loss, x_decoded
+
+
+def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
+    """
+    Creates and returns a z variable taken from a normal distribution.
+
+    :param batch_size: size of the batch
+    :param zsize: size of the z latent space
+    :return: created variable
+    """
+    z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
+    return k.variable(z)
+
+
 def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
     """
     Extracts a batch from data.