diff --git a/src/twomartens/masterthesis/aae/run.py b/src/twomartens/masterthesis/aae/run.py index 4ad3164..f78d59e 100644 --- a/src/twomartens/masterthesis/aae/run.py +++ b/src/twomartens/masterthesis/aae/run.py @@ -39,9 +39,9 @@ tfe = tf.contrib.eager def run_simple(dataset: tf.data.Dataset, iteration: int, weights_prefix: str, - channels: int = 1, - zsize: int = 32, - batch_size: int = 128, + channels: int = 3, + zsize: int = 64, + batch_size: int = 16, verbose: bool = False, debug: bool = False) -> None: """ @@ -53,9 +53,9 @@ def run_simple(dataset: tf.data.Dataset, dataset: run dataset iteration: identifier for the used training run weights_prefix: prefix for trained weights directory - channels: number of channels in input image (default: 1) - zsize: size of the intermediary z (default: 32) - batch_size: size of each batch (default: 128) + channels: number of channels in input image (default: 3) + zsize: size of the intermediary z (default: 64) + batch_size: size of each batch (default: 16) verbose: if True training progress is printed to console (default: False) debug: if True summaries are collected (default: False) """ diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 43e666c..5f9037b 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -47,11 +47,11 @@ LOG_FREQUENCY: int = 10 def train_simple(dataset: tf.data.Dataset, iteration: int, weights_prefix: str, - channels: int = 1, - zsize: int = 32, - lr: float = 0.002, - train_epoch: int = 80, - batch_size: int = 128, + channels: int = 3, + zsize: int = 64, + lr: float = 0.0001, + train_epoch: int = 1, + batch_size: int = 16, verbose: bool = False, debug: bool = False) -> None: """ @@ -68,11 +68,11 @@ def train_simple(dataset: tf.data.Dataset, dataset: train dataset iteration: identifier for the current training run weights_prefix: prefix for weights directory - channels: number of channels in input image (default: 1) - zsize: size of the intermediary z (default: 32) - lr: initial learning rate (default: 0.002) - train_epoch: number of epochs to train (default: 80) - batch_size: size of each batch (default: 128) + channels: number of channels in input image (default: 3) + zsize: size of the intermediary z (default: 64) + lr: initial learning rate (default: 0.0001) + train_epoch: number of epochs to train (default: 1) + batch_size: size of each batch (default: 16) verbose: if True training progress is printed to console (default: False) debug: if True summaries are collected (default: False) """