Improved imports according to Google standards
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -41,18 +41,20 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Dict, Sequence, Tuple
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.ops import summary_ops_v2
|
from tensorflow.python.ops import summary_ops_v2
|
||||||
|
|
||||||
from .model import Decoder, Encoder, XDiscriminator, ZDiscriminator
|
from twomartens.masterthesis.aae import model
|
||||||
from .util import prepare_image
|
from twomartens.masterthesis.aae import util
|
||||||
|
|
||||||
# shortcuts for tensorflow sub packages and classes
|
# shortcuts for tensorflow sub packages and classes
|
||||||
k = tf.keras.backend
|
K = tf.keras.backend
|
||||||
AdamOptimizer = tf.train.AdamOptimizer
|
|
||||||
tfe = tf.contrib.eager
|
tfe = tf.contrib.eager
|
||||||
|
|
||||||
GRACE: int = 10
|
GRACE: int = 10
|
||||||
@ -163,9 +165,9 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# non-preserved tensors
|
# non-preserved tensors
|
||||||
y_real = k.ones(batch_size)
|
y_real = K.ones(batch_size)
|
||||||
y_fake = k.zeros(batch_size)
|
y_fake = K.zeros(batch_size)
|
||||||
sample = k.expand_dims(k.expand_dims(k.random_normal((64, zsize)), axis=1), axis=1)
|
sample = K.expand_dims(K.expand_dims(K.random_normal((64, zsize)), axis=1), axis=1)
|
||||||
# z generator function
|
# z generator function
|
||||||
z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
|
z_generator = functools.partial(_get_z_variable, batch_size=batch_size, zsize=zsize)
|
||||||
|
|
||||||
@ -180,28 +182,28 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
|
|
||||||
# checkpointed tensors and variables
|
# checkpointed tensors and variables
|
||||||
checkpointables = {
|
checkpointables = {
|
||||||
'learning_rate_var': k.variable(lr),
|
'learning_rate_var': K.variable(lr),
|
||||||
}
|
}
|
||||||
checkpointables.update({
|
checkpointables.update({
|
||||||
# get models
|
# get models
|
||||||
'encoder': Encoder(zsize),
|
'encoder': model.Encoder(zsize),
|
||||||
'decoder': Decoder(channels),
|
'decoder': model.Decoder(channels),
|
||||||
'z_discriminator': ZDiscriminator(),
|
'z_discriminator': model.ZDiscriminator(),
|
||||||
'x_discriminator': XDiscriminator(),
|
'x_discriminator': model.XDiscriminator(),
|
||||||
# define optimizers
|
# define optimizers
|
||||||
'decoder_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
'decoder_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
||||||
'enc_dec_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
'enc_dec_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999),
|
||||||
'z_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
'z_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
'x_discriminator_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
beta1=0.5, beta2=0.999),
|
beta1=0.5, beta2=0.999),
|
||||||
# global step counter
|
# global step counter
|
||||||
'epoch_var': k.variable(-1, dtype=tf.int64),
|
'epoch_var': K.variable(-1, dtype=tf.int64),
|
||||||
'global_step': tf.train.get_or_create_global_step(),
|
'global_step': tf.train.get_or_create_global_step(),
|
||||||
'global_step_decoder': k.variable(0, dtype=tf.int64),
|
'global_step_decoder': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
|
'global_step_enc_dec': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_xd': k.variable(0, dtype=tf.int64),
|
'global_step_xd': K.variable(0, dtype=tf.int64),
|
||||||
'global_step_zd': k.variable(0, dtype=tf.int64),
|
'global_step_zd': K.variable(0, dtype=tf.int64),
|
||||||
})
|
})
|
||||||
|
|
||||||
# checkpoint
|
# checkpoint
|
||||||
@ -243,10 +245,10 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
))
|
))
|
||||||
|
|
||||||
# save sample image summary
|
# save sample image summary
|
||||||
def _save_sample(decoder: Decoder, global_step: tf.Variable, **kwargs) -> None:
|
def _save_sample(decoder: model.Decoder, global_step: tf.Variable, **kwargs) -> None:
|
||||||
resultsample = decoder(sample).cpu()
|
resultsample = decoder(sample).cpu()
|
||||||
grid = prepare_image(resultsample)
|
grid = util.prepare_image(resultsample)
|
||||||
summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0),
|
summary_ops_v2.image(name='sample', tensor=K.expand_dims(grid, axis=0),
|
||||||
max_images=1, step=global_step)
|
max_images=1, step=global_step)
|
||||||
|
|
||||||
with summary_ops_v2.always_record_summaries():
|
with summary_ops_v2.always_record_summaries():
|
||||||
@ -310,8 +312,8 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
verbose: bool,
|
verbose: bool,
|
||||||
targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
|
targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable],
|
||||||
learning_rate_var: tf.Variable,
|
learning_rate_var: tf.Variable,
|
||||||
decoder: Decoder, encoder: Encoder, x_discriminator: XDiscriminator,
|
decoder: model.Decoder, encoder: model.Encoder, x_discriminator: model.XDiscriminator,
|
||||||
z_discriminator: ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
|
z_discriminator: model.ZDiscriminator, decoder_optimizer: tf.train.Optimizer,
|
||||||
x_discriminator_optimizer: tf.train.Optimizer,
|
x_discriminator_optimizer: tf.train.Optimizer,
|
||||||
z_discriminator_optimizer: tf.train.Optimizer,
|
z_discriminator_optimizer: tf.train.Optimizer,
|
||||||
enc_dec_optimizer: tf.train.Optimizer,
|
enc_dec_optimizer: tf.train.Optimizer,
|
||||||
@ -390,10 +392,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
encoder_loss_avg(encoder_loss)
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
if int(global_step % LOG_FREQUENCY) == 0:
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
grid = util.prepare_image(comparison.cpu(), nrow=64)
|
||||||
summary_ops_v2.image(name='reconstruction',
|
summary_ops_v2.image(name='reconstruction',
|
||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
||||||
step=global_step)
|
step=global_step)
|
||||||
global_step.assign_add(1)
|
global_step.assign_add(1)
|
||||||
|
|
||||||
@ -413,7 +415,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def _train_xdiscriminator_step(x_discriminator: model.XDiscriminator, decoder: model.Decoder,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||||
@ -463,7 +465,7 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
return _xd_train_loss
|
return _xd_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
def _train_decoder_step(decoder: model.Decoder, x_discriminator: model.XDiscriminator,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
targets: tf.Tensor, global_step: tf.Variable,
|
targets: tf.Tensor, global_step: tf.Variable,
|
||||||
global_step_decoder: tf.Variable,
|
global_step_decoder: tf.Variable,
|
||||||
@ -502,7 +504,7 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
|||||||
return _decoder_train_loss
|
return _decoder_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder,
|
def _train_zdiscriminator_step(z_discriminator: model.ZDiscriminator, encoder: model.Encoder,
|
||||||
optimizer: tf.train.Optimizer,
|
optimizer: tf.train.Optimizer,
|
||||||
inputs: tf.Tensor, targets_real: tf.Tensor,
|
inputs: tf.Tensor, targets_real: tf.Tensor,
|
||||||
targets_fake: tf.Tensor, global_step: tf.Variable,
|
targets_fake: tf.Tensor, global_step: tf.Variable,
|
||||||
@ -553,7 +555,7 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
|
|||||||
return _zd_train_loss
|
return _zd_train_loss
|
||||||
|
|
||||||
|
|
||||||
def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDiscriminator,
|
def _train_enc_dec_step(encoder: model.Encoder, decoder: model.Decoder, z_discriminator: model.ZDiscriminator,
|
||||||
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
optimizer: tf.train.Optimizer, inputs: tf.Tensor,
|
||||||
targets: tf.Tensor, global_step: tf.Variable,
|
targets: tf.Tensor, global_step: tf.Variable,
|
||||||
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
global_step_enc_dec: tf.Variable) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
||||||
@ -608,8 +610,8 @@ def _get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
|||||||
:param zsize: size of the z latent space
|
:param zsize: size of the z latent space
|
||||||
:return: created variable
|
:return: created variable
|
||||||
"""
|
"""
|
||||||
z = k.reshape(k.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
|
z = K.reshape(K.random_normal((batch_size, zsize)), (-1, 1, 1, zsize))
|
||||||
return k.variable(z)
|
return K.variable(z)
|
||||||
|
|
||||||
|
|
||||||
def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||||
@ -620,7 +622,7 @@ def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tens
|
|||||||
:param label: label tensor
|
:param label: label tensor
|
||||||
:return: normalized tensor
|
:return: normalized tensor
|
||||||
"""
|
"""
|
||||||
return k.expand_dims(tf.divide(feature, 255.0)), label
|
return K.expand_dims(tf.divide(feature, 255.0)), label
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -22,7 +22,9 @@ Functions:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Sequence, Tuple, Union
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|||||||
Reference in New Issue
Block a user