Removed hard-coded reliance on specific batch size
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -51,6 +51,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
zsize: int = 32,
|
||||
lr: float = 0.002,
|
||||
train_epoch: int = 80,
|
||||
batch_size: int = 128,
|
||||
verbose: bool = True) -> None:
|
||||
"""
|
||||
Trains auto-encoder for given data set.
|
||||
@ -70,6 +71,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
zsize: size of the intermediary z (default: 32)
|
||||
lr: initial learning rate (default: 0.002)
|
||||
train_epoch: number of epochs to train (default: 80)
|
||||
batch_size: size of each batch (default: 128)
|
||||
verbose: if True prints train progress info to console (default: True)
|
||||
"""
|
||||
|
||||
@ -114,6 +116,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
_epoch = epoch + previous_epochs
|
||||
outputs = _train_one_epoch_simple(_epoch, dataset,
|
||||
verbose=verbose,
|
||||
batch_size=batch_size,
|
||||
**checkpointables)
|
||||
|
||||
if verbose:
|
||||
@ -136,6 +139,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
def _train_one_epoch_simple(epoch: int,
|
||||
dataset: tf.data.Dataset,
|
||||
verbose: bool,
|
||||
batch_size: int,
|
||||
learning_rate_var: tf.Variable,
|
||||
decoder: model.Decoder,
|
||||
encoder: model.Encoder,
|
||||
@ -167,8 +171,8 @@ def _train_one_epoch_simple(epoch: int,
|
||||
enc_dec_loss_avg(reconstruction_loss)
|
||||
|
||||
if int(global_step % LOG_FREQUENCY) == 0:
|
||||
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=64)
|
||||
comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
|
||||
summary_ops_v2.image(name='reconstruction',
|
||||
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
||||
step=global_step)
|
||||
|
||||
@ -167,7 +167,7 @@ def train(dataset: tf.data.Dataset,
|
||||
_epoch = epoch + previous_epochs
|
||||
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
|
||||
targets_fake=y_fake, z_generator=z_generator,
|
||||
verbose=verbose,
|
||||
verbose=verbose, batch_size=batch_size,
|
||||
**checkpointables)
|
||||
|
||||
if verbose:
|
||||
@ -249,6 +249,7 @@ def _train_one_epoch(epoch: int,
|
||||
dataset: tf.data.Dataset,
|
||||
targets_real: tf.Tensor,
|
||||
verbose: bool,
|
||||
batch_size: int,
|
||||
targets_fake: tf.Tensor,
|
||||
z_generator: Callable[[], tf.Variable],
|
||||
learning_rate_var: tf.Variable,
|
||||
@ -335,8 +336,8 @@ def _train_one_epoch(epoch: int,
|
||||
encoder_loss_avg(encoder_loss)
|
||||
|
||||
if int(global_step % LOG_FREQUENCY) == 0:
|
||||
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=64)
|
||||
comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
|
||||
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
|
||||
summary_ops_v2.image(name='reconstruction',
|
||||
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
||||
step=global_step)
|
||||
|
||||
Reference in New Issue
Block a user