Extracted training steps of models into separate functions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""aae.train.py: contains training functionality"""
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Callable, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -101,6 +102,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
y_real_z = k.ones(batch_size)
|
||||
y_fake_z = k.zeros(batch_size)
|
||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
||||
|
||||
z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||
|
||||
global_step_decoder = k.variable(0)
|
||||
global_step_enc_dec = k.variable(0)
|
||||
global_step_xd = k.variable(0)
|
||||
@ -140,75 +144,48 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
for it in range(nr_batches):
|
||||
x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))
|
||||
# x discriminator
|
||||
with tf.GradientTape() as tape:
|
||||
xd_result = tf.squeeze(x_discriminator(x))
|
||||
xd_real_loss = binary_crossentropy(y_real, xd_result)
|
||||
z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
|
||||
z = k.variable(z)
|
||||
|
||||
x_fake = decoder(z)
|
||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||
xd_fake_loss = binary_crossentropy(y_fake, xd_result)
|
||||
|
||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||
|
||||
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
||||
x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
||||
global_step=global_step_xd)
|
||||
_xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||
decoder=decoder,
|
||||
optimizer=x_discriminator_optimizer,
|
||||
inputs=x,
|
||||
targets_real=y_real,
|
||||
targets_fake=y_fake,
|
||||
global_step=global_step_xd,
|
||||
z_generator=z_generator)
|
||||
xd_loss_avg(_xd_train_loss)
|
||||
|
||||
# --------
|
||||
# decoder
|
||||
with tf.GradientTape() as tape:
|
||||
z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
|
||||
z = k.variable(z)
|
||||
|
||||
x_fake = decoder(z)
|
||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||
_decoder_train_loss = binary_crossentropy(y_real, xd_result)
|
||||
|
||||
decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
||||
decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
|
||||
global_step=global_step_decoder)
|
||||
_decoder_train_loss = train_decoder_step(decoder=decoder,
|
||||
x_discriminator=x_discriminator,
|
||||
optimizer=decoder_optimizer,
|
||||
targets=y_real,
|
||||
global_step=global_step_decoder,
|
||||
z_generator=z_generator)
|
||||
decoder_loss_avg(_decoder_train_loss)
|
||||
|
||||
# ---------
|
||||
# z discriminator
|
||||
with tf.GradientTape() as tape:
|
||||
z = k.reshape(k.random_normal((batch_size, zsize)), (-1, zsize))
|
||||
z = k.variable(z)
|
||||
|
||||
zd_result = tf.squeeze(z_discriminator(z))
|
||||
zd_real_loss = binary_crossentropy(y_real_z, zd_result)
|
||||
|
||||
z = tf.squeeze(encoder(x))
|
||||
zd_result = tf.squeeze(z_discriminator(z))
|
||||
zd_fake_loss = binary_crossentropy(y_fake_z, zd_result)
|
||||
|
||||
_zd_train_loss = zd_real_loss + zd_fake_loss
|
||||
|
||||
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
||||
z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
||||
global_step=global_step_zd)
|
||||
_zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
|
||||
encoder=encoder,
|
||||
optimizer=z_discriminator_optimizer,
|
||||
inputs=x,
|
||||
targets_real=y_real,
|
||||
targets_fake=y_fake,
|
||||
global_step=global_step_zd,
|
||||
z_generator=z_generator)
|
||||
zd_loss_avg(_zd_train_loss)
|
||||
|
||||
# -----------
|
||||
# encoder + decoder
|
||||
with tf.GradientTape() as tape:
|
||||
z = encoder(x)
|
||||
x_decoded = decoder(z)
|
||||
|
||||
zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
|
||||
encoder_loss = binary_crossentropy(y_real_z, zd_result) * 2.0
|
||||
recovery_loss = binary_crossentropy(x, x_decoded)
|
||||
_enc_dec_train_loss = encoder_loss + recovery_loss
|
||||
|
||||
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
||||
encoder.trainable_variables + decoder.trainable_variables)
|
||||
enc_dec_optimizer.apply_gradients(zip(enc_dec_grads,
|
||||
encoder.trainable_variables + decoder.trainable_variables),
|
||||
global_step=global_step_enc_dec)
|
||||
enc_dec_loss_avg(recovery_loss)
|
||||
encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
|
||||
decoder=decoder,
|
||||
z_discriminator=z_discriminator,
|
||||
optimizer=enc_dec_optimizer,
|
||||
inputs=x,
|
||||
targets=y_real,
|
||||
global_step=global_step_enc_dec)
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
encoder_loss_avg(encoder_loss)
|
||||
|
||||
if it == 0:
|
||||
@ -261,6 +238,153 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
pickle.dump(zd_loss_history, file)
|
||||
|
||||
|
||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
optimizer: tf.train.Optimizer,
|
||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||
"""
|
||||
Trains the x discriminator model for one step (one batch).
|
||||
|
||||
:param x_discriminator: instance of x discriminator model
|
||||
:param decoder: instance of decoder model
|
||||
:param optimizer: instance of chosen optimizer
|
||||
:param inputs: inputs from dataset
|
||||
:param targets_real: target tensor for real loss calculation
|
||||
:param targets_fake: target tensor for fake loss calculation
|
||||
:param global_step: the global step variable
|
||||
:param z_generator: callable function that returns a z variable
|
||||
:return: the calculated loss
|
||||
"""
|
||||
with tf.GradientTape() as tape:
|
||||
xd_result = tf.squeeze(x_discriminator(inputs))
|
||||
xd_real_loss = binary_crossentropy(targets_real, xd_result)
|
||||
|
||||
z = z_generator()
|
||||
x_fake = decoder(z)
|
||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||
xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
|
||||
|
||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||
|
||||
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
||||
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
||||
global_step=global_step)
|
||||
|
||||
return _xd_train_loss
|
||||
|
||||
|
||||
def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
||||
optimizer: tf.train.Optimizer,
|
||||
targets: tf.Tensor, global_step: tf.Variable,
|
||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||
"""
|
||||
Trains the decoder model for one step (one batch).
|
||||
|
||||
:param decoder: instance of decoder model
|
||||
:param x_discriminator: instance of the x discriminator model
|
||||
:param optimizer: instance of chosen optimizer
|
||||
:param targets: target tensor for loss calculation
|
||||
:param global_step: the global step variable
|
||||
:param z_generator: callable function that returns a z variable
|
||||
:return: the calculated loss
|
||||
"""
|
||||
with tf.GradientTape() as tape:
|
||||
z = z_generator()
|
||||
|
||||
x_fake = decoder(z)
|
||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||
_decoder_train_loss = binary_crossentropy(targets, xd_result)
|
||||
|
||||
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
||||
optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
|
||||
global_step=global_step)
|
||||
|
||||
return _decoder_train_loss
|
||||
|
||||
|
||||
def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
|
||||
optimizer: tf.train.Optimizer,
|
||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||
z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
|
||||
"""
|
||||
Trains the z discriminator one step (one batch).
|
||||
|
||||
:param z_discriminator: instance of z discriminator model
|
||||
:param encoder: instance of encoder model
|
||||
:param optimizer: instance of chosen optimizer
|
||||
:param inputs: inputs from dataset
|
||||
:param targets_real: target tensor for real loss calculation
|
||||
:param targets_fake: target tensor for fake loss calculation
|
||||
:param global_step: the global step variable
|
||||
:param z_generator: callable function that returns a z variable
|
||||
:return: the calculated loss
|
||||
"""
|
||||
with tf.GradientTape() as tape:
|
||||
z = z_generator()
|
||||
|
||||
zd_result = tf.squeeze(z_discriminator(z))
|
||||
zd_real_loss = binary_crossentropy(targets_real, zd_result)
|
||||
|
||||
z = tf.squeeze(encoder(inputs))
|
||||
zd_result = tf.squeeze(z_discriminator(z))
|
||||
zd_fake_loss = binary_crossentropy(targets_fake, zd_result)
|
||||
|
||||
_zd_train_loss = zd_real_loss + zd_fake_loss
|
||||
|
||||
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
||||
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
||||
global_step=global_step)
|
||||
|
||||
return _zd_train_loss
|
||||
|
||||
|
||||
def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
||||
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
||||
targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Trains the encoder and decoder jointly for one step (one batch).
|
||||
|
||||
:param encoder: instance of encoder model
|
||||
:param decoder: instance of decoder model
|
||||
:param z_discriminator: instance of z discriminator model
|
||||
:param optimizer: instance of chosen optimizer
|
||||
:param inputs: inputs from dataset
|
||||
:param targets: target tensor for loss calculation
|
||||
:param global_step: the global step variable
|
||||
:return: tuple of encoder loss, reconstruction loss, reconstructed input
|
||||
"""
|
||||
with tf.GradientTape() as tape:
|
||||
z = encoder(inputs)
|
||||
x_decoded = decoder(z)
|
||||
|
||||
zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
|
||||
encoder_loss = binary_crossentropy(targets, zd_result) * 2.0
|
||||
reconstruction_loss = binary_crossentropy(inputs, x_decoded)
|
||||
_enc_dec_train_loss = encoder_loss + reconstruction_loss
|
||||
|
||||
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
||||
encoder.trainable_variables + decoder.trainable_variables)
|
||||
optimizer.apply_gradients(zip(enc_dec_grads,
|
||||
encoder.trainable_variables + decoder.trainable_variables),
|
||||
global_step=global_step)
|
||||
|
||||
return encoder_loss, reconstruction_loss, x_decoded
|
||||
|
||||
|
||||
def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
"""
|
||||
Creates and returns a z variable taken from a normal distribution.
|
||||
|
||||
:param batch_size: size of the batch
|
||||
:param zsize: size of the z latent space
|
||||
:return: created variable
|
||||
"""
|
||||
z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
|
||||
return k.variable(z)
|
||||
|
||||
|
||||
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
|
||||
"""
|
||||
Extracts a batch from data.
|
||||
|
||||
Reference in New Issue
Block a user