diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index a106fcd..d6b585c 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -51,6 +51,7 @@ def train_simple(dataset: tf.data.Dataset, zsize: int = 32, lr: float = 0.002, train_epoch: int = 80, + batch_size: int = 128, verbose: bool = True) -> None: """ Trains auto-encoder for given data set. @@ -70,6 +71,7 @@ def train_simple(dataset: tf.data.Dataset, zsize: size of the intermediary z (default: 32) lr: initial learning rate (default: 0.002) train_epoch: number of epochs to train (default: 80) + batch_size: size of each batch (default: 128) verbose: if True prints train progress info to console (default: True) """ @@ -114,6 +116,7 @@ def train_simple(dataset: tf.data.Dataset, _epoch = epoch + previous_epochs outputs = _train_one_epoch_simple(_epoch, dataset, verbose=verbose, + batch_size=batch_size, **checkpointables) if verbose: @@ -136,6 +139,7 @@ def train_simple(dataset: tf.data.Dataset, def _train_one_epoch_simple(epoch: int, dataset: tf.data.Dataset, verbose: bool, + batch_size: int, learning_rate_var: tf.Variable, decoder: model.Decoder, encoder: model.Encoder, @@ -167,8 +171,8 @@ def _train_one_epoch_simple(epoch: int, enc_dec_loss_avg(reconstruction_loss) if int(global_step % LOG_FREQUENCY) == 0: - comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0) - grid = util.prepare_image(comparison.cpu(), nrow=64) + comparison = K.concatenate([x[:batch_size // 2], x_decoded[:batch_size // 2]], axis=0) + grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2)) summary_ops_v2.image(name='reconstruction', tensor=K.expand_dims(grid, axis=0), max_images=1, step=global_step) diff --git a/src/twomartens/masterthesis/aae/train_aae.py b/src/twomartens/masterthesis/aae/train_aae.py index 0733afb..b746bb1 100644 --- a/src/twomartens/masterthesis/aae/train_aae.py +++ b/src/twomartens/masterthesis/aae/train_aae.py @@ 
-167,7 +167,7 @@ def train(dataset: tf.data.Dataset, _epoch = epoch + previous_epochs outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real, targets_fake=y_fake, z_generator=z_generator, - verbose=verbose, + verbose=verbose, batch_size=batch_size, **checkpointables) if verbose: @@ -249,6 +249,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tensor, verbose: bool, + batch_size: int, targets_fake: tf.Tensor, z_generator: Callable[[], tf.Variable], learning_rate_var: tf.Variable, @@ -335,8 +336,8 @@ def _train_one_epoch(epoch: int, encoder_loss_avg(encoder_loss) if int(global_step % LOG_FREQUENCY) == 0: - comparison = K.concatenate([x[:64], x_decoded[:64]], axis=0) - grid = util.prepare_image(comparison.cpu(), nrow=64) + comparison = K.concatenate([x[:batch_size // 2], x_decoded[:batch_size // 2]], axis=0) + grid = util.prepare_image(comparison.cpu(), nrow=int(batch_size/2)) summary_ops_v2.image(name='reconstruction', tensor=K.expand_dims(grid, axis=0), max_images=1, step=global_step)