Removed hard-coded reliance on specific batch size
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -51,6 +51,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
zsize: int = 32,
|
zsize: int = 32,
|
||||||
lr: float = 0.002,
|
lr: float = 0.002,
|
||||||
train_epoch: int = 80,
|
train_epoch: int = 80,
|
||||||
|
batch_size: int = 128,
|
||||||
verbose: bool = True) -> None:
|
verbose: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Trains auto-encoder for given data set.
|
Trains auto-encoder for given data set.
|
||||||
@ -70,6 +71,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
zsize: size of the intermediary z (default: 32)
|
zsize: size of the intermediary z (default: 32)
|
||||||
lr: initial learning rate (default: 0.002)
|
lr: initial learning rate (default: 0.002)
|
||||||
train_epoch: number of epochs to train (default: 80)
|
train_epoch: number of epochs to train (default: 80)
|
||||||
|
batch_size: size of each batch (default: 128)
|
||||||
verbose: if True prints train progress info to console (default: True)
|
verbose: if True prints train progress info to console (default: True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -114,6 +116,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
_epoch = epoch + previous_epochs
|
_epoch = epoch + previous_epochs
|
||||||
outputs = _train_one_epoch_simple(_epoch, dataset,
|
outputs = _train_one_epoch_simple(_epoch, dataset,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
batch_size=batch_size,
|
||||||
**checkpointables)
|
**checkpointables)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
@ -136,6 +139,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
def _train_one_epoch_simple(epoch: int,
|
def _train_one_epoch_simple(epoch: int,
|
||||||
dataset: tf.data.Dataset,
|
dataset: tf.data.Dataset,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
|
batch_size: int,
|
||||||
learning_rate_var: tf.Variable,
|
learning_rate_var: tf.Variable,
|
||||||
decoder: model.Decoder,
|
decoder: model.Decoder,
|
||||||
encoder: model.Encoder,
|
encoder: model.Encoder,
|
||||||
@ -167,8 +171,8 @@ def _train_one_epoch_simple(epoch: int,
|
|||||||
enc_dec_loss_avg(reconstruction_loss)
|
enc_dec_loss_avg(reconstruction_loss)
|
||||||
|
|
||||||
if int(global_step % LOG_FREQUENCY) == 0:
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
|
comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
|
||||||
grid = util.prepare_image(comparison.cpu(), nrow=64)
|
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
|
||||||
summary_ops_v2.image(name='reconstruction',
|
summary_ops_v2.image(name='reconstruction',
|
||||||
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
||||||
step=global_step)
|
step=global_step)
|
||||||
|
|||||||
@ -167,7 +167,7 @@ def train(dataset: tf.data.Dataset,
|
|||||||
_epoch = epoch + previous_epochs
|
_epoch = epoch + previous_epochs
|
||||||
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
|
outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real,
|
||||||
targets_fake=y_fake, z_generator=z_generator,
|
targets_fake=y_fake, z_generator=z_generator,
|
||||||
verbose=verbose,
|
verbose=verbose, batch_size=batch_size,
|
||||||
**checkpointables)
|
**checkpointables)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
@ -249,6 +249,7 @@ def _train_one_epoch(epoch: int,
|
|||||||
dataset: tf.data.Dataset,
|
dataset: tf.data.Dataset,
|
||||||
targets_real: tf.Tensor,
|
targets_real: tf.Tensor,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
|
batch_size: int,
|
||||||
targets_fake: tf.Tensor,
|
targets_fake: tf.Tensor,
|
||||||
z_generator: Callable[[], tf.Variable],
|
z_generator: Callable[[], tf.Variable],
|
||||||
learning_rate_var: tf.Variable,
|
learning_rate_var: tf.Variable,
|
||||||
@ -335,8 +336,8 @@ def _train_one_epoch(epoch: int,
|
|||||||
encoder_loss_avg(encoder_loss)
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
if int(global_step % LOG_FREQUENCY) == 0:
|
if int(global_step % LOG_FREQUENCY) == 0:
|
||||||
comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0)
|
comparison = K.concatenate([x[:batch_size/2], x_decoded[:batch_size/2]], axis=0)
|
||||||
grid = util.prepare_image(comparison.cpu(), nrow=64)
|
grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2))
|
||||||
summary_ops_v2.image(name='reconstruction',
|
summary_ops_v2.image(name='reconstruction',
|
||||||
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
tensor=K.expand_dims(grid, axis=0), max_images=1,
|
||||||
step=global_step)
|
step=global_step)
|
||||||
|
|||||||
Reference in New Issue
Block a user