Removed hard-coded reliance on specific batch size

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-09 10:33:14 +02:00
parent be501d606f
commit 523a9c8e71
2 changed files with 10 additions and 5 deletions

View File

@@ -51,6 +51,7 @@ def train_simple(dataset: tf.data.Dataset,
zsize: int = 32, zsize: int = 32,
lr: float = 0.002, lr: float = 0.002,
train_epoch: int = 80, train_epoch: int = 80,
batch_size: int = 128,
verbose: bool = True) -> None: verbose: bool = True) -> None:
""" """
Trains auto-encoder for given data set. Trains auto-encoder for given data set.
@@ -70,6 +71,7 @@ def train_simple(dataset: tf.data.Dataset,
zsize: size of the intermediary z (default: 32) zsize: size of the intermediary z (default: 32)
lr: initial learning rate (default: 0.002) lr: initial learning rate (default: 0.002)
train_epoch: number of epochs to train (default: 80) train_epoch: number of epochs to train (default: 80)
batch_size: size of each batch (default: 128)
verbose: if True prints train progress info to console (default: True) verbose: if True prints train progress info to console (default: True)
""" """
@@ -114,6 +116,7 @@ def train_simple(dataset: tf.data.Dataset,
_epoch = epoch + previous_epochs _epoch = epoch + previous_epochs
outputs = _train_one_epoch_simple(_epoch, dataset, outputs = _train_one_epoch_simple(_epoch, dataset,
verbose=verbose, verbose=verbose,
batch_size=batch_size,
**checkpointables) **checkpointables)
if verbose: if verbose:
@@ -136,6 +139,7 @@ def train_simple(dataset: tf.data.Dataset,
def _train_one_epoch_simple(epoch: int, def _train_one_epoch_simple(epoch: int,
dataset: tf.data.Dataset, dataset: tf.data.Dataset,
verbose: bool, verbose: bool,
batch_size: int,
learning_rate_var: tf.Variable, learning_rate_var: tf.Variable,
decoder: model.Decoder, decoder: model.Decoder,
encoder: model.Encoder, encoder: model.Encoder,
@@ -167,8 +171,8 @@ def _train_one_epoch_simple(epoch: int,
enc_dec_loss_avg(reconstruction_loss) enc_dec_loss_avg(reconstruction_loss)
if int(global_step % LOG_FREQUENCY) == 0: if int(global_step % LOG_FREQUENCY) == 0:
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0) comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
grid = util.prepare_image(comparison.cpu(), nrow=64) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
summary_ops_v2.image(name='reconstruction', summary_ops_v2.image(name='reconstruction',
tensor=K.expand_dims(grid, axis=0), max_images=1, tensor=K.expand_dims(grid, axis=0), max_images=1,
step=global_step) step=global_step)

View File

@@ -167,7 +167,7 @@ def train(dataset: tf.data.Dataset,
_epoch = epoch + previous_epochs _epoch = epoch + previous_epochs
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real, outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
targets_fake=y_fake, z_generator=z_generator, targets_fake=y_fake, z_generator=z_generator,
verbose=verbose, verbose=verbose, batch_size=batch_size,
**checkpointables) **checkpointables)
if verbose: if verbose:
@@ -249,6 +249,7 @@ def _train_one_epoch(epoch: int,
dataset: tf.data.Dataset, dataset: tf.data.Dataset,
targets_real: tf.Tensor, targets_real: tf.Tensor,
verbose: bool, verbose: bool,
batch_size: int,
targets_fake: tf.Tensor, targets_fake: tf.Tensor,
z_generator: Callable[[], tf.Variable], z_generator: Callable[[], tf.Variable],
learning_rate_var: tf.Variable, learning_rate_var: tf.Variable,
@@ -335,8 +336,8 @@ def _train_one_epoch(epoch: int,
encoder_loss_avg(encoder_loss) encoder_loss_avg(encoder_loss)
if int(global_step % LOG_FREQUENCY) == 0: if int(global_step % LOG_FREQUENCY) == 0:
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0) comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
grid = util.prepare_image(comparison.cpu(), nrow=64) grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
summary_ops_v2.image(name='reconstruction', summary_ops_v2.image(name='reconstruction',
tensor=K.expand_dims(grid, axis=0), max_images=1, tensor=K.expand_dims(grid, axis=0), max_images=1,
step=global_step) step=global_step)