Extracted training steps of models into separate functions

Signed-off-by: Jim Martens <github@2martens.de>
2019-02-08 06:37:45 +01:00
parent 887435ddd7
commit 14cb27afd7


@@ -15,11 +15,12 @@
 # limitations under the License.
 """aae.train.py: contains training functionality"""
+import functools
 import os
 import pickle
 import random
 import time
-from typing import Sequence, Tuple
+from typing import Callable, Sequence, Tuple
 
 import numpy as np
 import tensorflow as tf
@@ -101,6 +102,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     y_real_z = k.ones(batch_size)
     y_fake_z = k.zeros(batch_size)
     sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
+
+    z_generator = functools.partial(get_z_variable, batch_size=batch_size, zsize=zsize)
+
     global_step_decoder = k.variable(0)
     global_step_enc_dec = k.variable(0)
     global_step_xd = k.variable(0)
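
Editor's note on the pattern introduced above: functools.partial pre-binds batch_size and zsize, so the step functions added later in this commit can draw a fresh latent sample by calling z_generator() with no arguments. A minimal standalone sketch follows; it re-states get_z_variable from this commit so the snippet runs on its own, and it assumes TF 1.x with eager execution enabled (as the GradientTape-based training code implies) plus illustrative sizes that are not taken from the project.

import functools

import tensorflow as tf
from tensorflow.keras import backend as k  # Keras backend alias, mirroring the `k` used in aae.train.py

tf.enable_eager_execution()  # TF 1.x: gives random_normal a concrete value immediately


def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
    """Draw a fresh latent sample shaped (batch_size, 1, 1, zsize) for the decoder."""
    z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
    return k.variable(z)


# Bind the hyperparameters once; each call then produces a new random z.
z_generator = functools.partial(get_z_variable, batch_size=128, zsize=32)

z = z_generator()        # tf.Variable, shape (128, 1, 1, 32)
another = z_generator()  # independent sample, same shape

Passing the zero-argument callable into the step functions keeps their signatures free of the latent-space hyperparameters.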
@@ -140,75 +144,48 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         for it in range(nr_batches):
             x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size))

             # x discriminator
-            with tf.GradientTape() as tape:
-                xd_result = tf.squeeze(x_discriminator(x))
-                xd_real_loss = binary_crossentropy(y_real, xd_result)
-                z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
-                z = k.variable(z)
-
-                x_fake = decoder(z)
-                xd_result = tf.squeeze(x_discriminator(x_fake))
-                xd_fake_loss = binary_crossentropy(y_fake, xd_result)
-                _xd_train_loss = xd_real_loss + xd_fake_loss
-            xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
-            x_discriminator_optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
-                                                      global_step=global_step_xd)
+            _xd_train_loss = train_xdiscriminator_step(x_discriminator=x_discriminator,
+                                                       decoder=decoder,
+                                                       optimizer=x_discriminator_optimizer,
+                                                       inputs=x,
+                                                       targets_real=y_real,
+                                                       targets_fake=y_fake,
+                                                       global_step=global_step_xd,
+                                                       z_generator=z_generator)
             xd_loss_avg(_xd_train_loss)
             # --------

             # decoder
-            with tf.GradientTape() as tape:
-                z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
-                z = k.variable(z)
-
-                x_fake = decoder(z)
-                xd_result = tf.squeeze(x_discriminator(x_fake))
-                _decoder_train_loss = binary_crossentropy(y_real, xd_result)
-            decoder_grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
-            decoder_optimizer.apply_gradients(zip(decoder_grads, decoder.trainable_variables),
-                                              global_step=global_step_decoder)
+            _decoder_train_loss = train_decoder_step(decoder=decoder,
+                                                     x_discriminator=x_discriminator,
+                                                     optimizer=decoder_optimizer,
+                                                     targets=y_real,
+                                                     global_step=global_step_decoder,
+                                                     z_generator=z_generator)
             decoder_loss_avg(_decoder_train_loss)
             # ---------

             # z discriminator
-            with tf.GradientTape() as tape:
-                z = k.reshape(k.random_normal((batch_size, zsize)), (-1, zsize))
-                z = k.variable(z)
-
-                zd_result = tf.squeeze(z_discriminator(z))
-                zd_real_loss = binary_crossentropy(y_real_z, zd_result)
-
-                z = tf.squeeze(encoder(x))
-                zd_result = tf.squeeze(z_discriminator(z))
-                zd_fake_loss = binary_crossentropy(y_fake_z, zd_result)
-                _zd_train_loss = zd_real_loss + zd_fake_loss
-            zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
-            z_discriminator_optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
-                                                      global_step=global_step_zd)
+            _zd_train_loss = train_zdiscriminator_step(z_discriminator=z_discriminator,
+                                                       encoder=encoder,
+                                                       optimizer=z_discriminator_optimizer,
+                                                       inputs=x,
+                                                       targets_real=y_real,
+                                                       targets_fake=y_fake,
+                                                       global_step=global_step_zd,
+                                                       z_generator=z_generator)
             zd_loss_avg(_zd_train_loss)
             # -----------

             # encoder + decoder
-            with tf.GradientTape() as tape:
-                z = encoder(x)
-                x_decoded = decoder(z)
-
-                zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
-                encoder_loss = binary_crossentropy(y_real_z, zd_result) * 2.0
-                recovery_loss = binary_crossentropy(x, x_decoded)
-                _enc_dec_train_loss = encoder_loss + recovery_loss
-            enc_dec_grads = tape.gradient(_enc_dec_train_loss,
-                                          encoder.trainable_variables + decoder.trainable_variables)
-            enc_dec_optimizer.apply_gradients(zip(enc_dec_grads,
-                                                  encoder.trainable_variables + decoder.trainable_variables),
-                                              global_step=global_step_enc_dec)
-            enc_dec_loss_avg(recovery_loss)
+            encoder_loss, reconstruction_loss, x_decoded = train_enc_dec_step(encoder=encoder,
+                                                                              decoder=decoder,
+                                                                              z_discriminator=z_discriminator,
+                                                                              optimizer=enc_dec_optimizer,
+                                                                              inputs=x,
+                                                                              targets=y_real,
+                                                                              global_step=global_step_enc_dec)
+            enc_dec_loss_avg(reconstruction_loss)
             encoder_loss_avg(encoder_loss)

             if it == 0:
@@ -261,6 +238,153 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         pickle.dump(zd_loss_history, file)
+
+
+def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
+                              optimizer: tf.train.Optimizer,
+                              inputs: tf.Tensor, targets_real: tf.Tensor,
+                              targets_fake: tf.Tensor, global_step: tf.Variable,
+                              z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the x discriminator model for one step (one batch).
+
+    :param x_discriminator: instance of x discriminator model
+    :param decoder: instance of decoder model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets_real: target tensor for real loss calculation
+    :param targets_fake: target tensor for fake loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        xd_result = tf.squeeze(x_discriminator(inputs))
+        xd_real_loss = binary_crossentropy(targets_real, xd_result)
+
+        z = z_generator()
+        x_fake = decoder(z)
+        xd_result = tf.squeeze(x_discriminator(x_fake))
+        xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
+        _xd_train_loss = xd_real_loss + xd_fake_loss
+
+    xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
+    optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
+                              global_step=global_step)
+
+    return _xd_train_loss
+
+
+def train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
+                       optimizer: tf.train.Optimizer,
+                       targets: tf.Tensor, global_step: tf.Variable,
+                       z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the decoder model for one step (one batch).
+
+    :param decoder: instance of decoder model
+    :param x_discriminator: instance of the x discriminator model
+    :param optimizer: instance of chosen optimizer
+    :param targets: target tensor for loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        z = z_generator()
+        x_fake = decoder(z)
+        xd_result = tf.squeeze(x_discriminator(x_fake))
+        _decoder_train_loss = binary_crossentropy(targets, xd_result)
+
+    grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
+    optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
+                              global_step=global_step)
+
+    return _decoder_train_loss
+
+
+def train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
+                              optimizer: tf.train.Optimizer,
+                              inputs: tf.Tensor, targets_real: tf.Tensor,
+                              targets_fake: tf.Tensor, global_step: tf.Variable,
+                              z_generator: Callable[[], tf.Variable]) -> tf.Tensor:
+    """
+    Trains the z discriminator one step (one batch).
+
+    :param z_discriminator: instance of z discriminator model
+    :param encoder: instance of encoder model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets_real: target tensor for real loss calculation
+    :param targets_fake: target tensor for fake loss calculation
+    :param global_step: the global step variable
+    :param z_generator: callable function that returns a z variable
+    :return: the calculated loss
+    """
+    with tf.GradientTape() as tape:
+        z = z_generator()
+        zd_result = tf.squeeze(z_discriminator(z))
+        zd_real_loss = binary_crossentropy(targets_real, zd_result)
+
+        z = tf.squeeze(encoder(inputs))
+        zd_result = tf.squeeze(z_discriminator(z))
+        zd_fake_loss = binary_crossentropy(targets_fake, zd_result)
+        _zd_train_loss = zd_real_loss + zd_fake_loss
+
+    zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
+    optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
+                              global_step=global_step)
+
+    return _zd_train_loss
+
+
+def train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
+                       optimizer: tf.train.Optimizer, inputs: tf.Tensor,
+                       targets: tf.Tensor, global_step: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    """
+    Trains the encoder and decoder jointly for one step (one batch).
+
+    :param encoder: instance of encoder model
+    :param decoder: instance of decoder model
+    :param z_discriminator: instance of z discriminator model
+    :param optimizer: instance of chosen optimizer
+    :param inputs: inputs from dataset
+    :param targets: target tensor for loss calculation
+    :param global_step: the global step variable
+    :return: tuple of encoder loss, reconstruction loss, reconstructed input
+    """
+    with tf.GradientTape() as tape:
+        z = encoder(inputs)
+        x_decoded = decoder(z)
+
+        zd_result = tf.squeeze(z_discriminator(tf.squeeze(z)))
+        encoder_loss = binary_crossentropy(targets, zd_result) * 2.0
+        reconstruction_loss = binary_crossentropy(inputs, x_decoded)
+        _enc_dec_train_loss = encoder_loss + reconstruction_loss
+
+    enc_dec_grads = tape.gradient(_enc_dec_train_loss,
+                                  encoder.trainable_variables + decoder.trainable_variables)
+    optimizer.apply_gradients(zip(enc_dec_grads,
+                                  encoder.trainable_variables + decoder.trainable_variables),
+                              global_step=global_step)
+
+    return encoder_loss, reconstruction_loss, x_decoded
+
+
+def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
+    """
+    Creates and returns a z variable taken from a normal distribution.
+
+    :param batch_size: size of the batch
+    :param zsize: size of the z latent space
+    :return: created variable
+    """
+    z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
+    return k.variable(z)

 def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
     """
     Extracts a batch from data.
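
Editor's note: one practical benefit of extracting the training steps is that each step function can now be exercised on its own, for example in a quick smoke test. The sketch below only illustrates the call interface introduced by this commit; the toy Decoder/XDiscriminator stand-ins, their shapes, the learning rate, and the import path are assumptions, not the project's real models or layout.

import functools

import tensorflow as tf
from tensorflow.keras import backend as k

from aae import train  # module path assumed; adjust to the actual package layout

tf.enable_eager_execution()  # TF 1.x; the GradientTape-based step functions assume eager mode

# Toy stand-ins for the project's Decoder and XDiscriminator models (assumed shapes).
decoder = tf.keras.Sequential([
    tf.keras.layers.Conv2DTranspose(1, kernel_size=4, strides=4, input_shape=(1, 1, 32)),
])
x_discriminator = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(4, 4, 1)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

batch_size, zsize = 8, 32
z_generator = functools.partial(train.get_z_variable, batch_size=batch_size, zsize=zsize)

# One isolated decoder update, mirroring the keyword arguments used in train_mnist.
loss = train.train_decoder_step(decoder=decoder,
                                x_discriminator=x_discriminator,
                                optimizer=tf.train.AdamOptimizer(1e-3),
                                targets=k.ones(batch_size),
                                global_step=k.variable(0),
                                z_generator=z_generator)
print(loss.numpy().mean())  # average binary cross-entropy over the batch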