Corrected default values for train and run functions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -39,9 +39,9 @@ tfe = tf.contrib.eager
|
|||||||
def run_simple(dataset: tf.data.Dataset,
|
def run_simple(dataset: tf.data.Dataset,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
channels: int = 1,
|
channels: int = 3,
|
||||||
zsize: int = 32,
|
zsize: int = 64,
|
||||||
batch_size: int = 128,
|
batch_size: int = 16,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
debug: bool = False) -> None:
|
debug: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -53,9 +53,9 @@ def run_simple(dataset: tf.data.Dataset,
|
|||||||
dataset: run dataset
|
dataset: run dataset
|
||||||
iteration: identifier for the used training run
|
iteration: identifier for the used training run
|
||||||
weights_prefix: prefix for trained weights directory
|
weights_prefix: prefix for trained weights directory
|
||||||
channels: number of channels in input image (default: 1)
|
channels: number of channels in input image (default: 3)
|
||||||
zsize: size of the intermediary z (default: 32)
|
zsize: size of the intermediary z (default: 64)
|
||||||
batch_size: size of each batch (default: 128)
|
batch_size: size of each batch (default: 16)
|
||||||
verbose: if True training progress is printed to console (default: False)
|
verbose: if True training progress is printed to console (default: False)
|
||||||
debug: if True summaries are collected (default: False)
|
debug: if True summaries are collected (default: False)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -47,11 +47,11 @@ LOG_FREQUENCY: int = 10
|
|||||||
def train_simple(dataset: tf.data.Dataset,
|
def train_simple(dataset: tf.data.Dataset,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
channels: int = 1,
|
channels: int = 3,
|
||||||
zsize: int = 32,
|
zsize: int = 64,
|
||||||
lr: float = 0.002,
|
lr: float = 0.0001,
|
||||||
train_epoch: int = 80,
|
train_epoch: int = 1,
|
||||||
batch_size: int = 128,
|
batch_size: int = 16,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
debug: bool = False) -> None:
|
debug: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -68,11 +68,11 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
dataset: train dataset
|
dataset: train dataset
|
||||||
iteration: identifier for the current training run
|
iteration: identifier for the current training run
|
||||||
weights_prefix: prefix for weights directory
|
weights_prefix: prefix for weights directory
|
||||||
channels: number of channels in input image (default: 1)
|
channels: number of channels in input image (default: 3)
|
||||||
zsize: size of the intermediary z (default: 32)
|
zsize: size of the intermediary z (default: 64)
|
||||||
lr: initial learning rate (default: 0.002)
|
lr: initial learning rate (default: 0.0001)
|
||||||
train_epoch: number of epochs to train (default: 80)
|
train_epoch: number of epochs to train (default: 1)
|
||||||
batch_size: size of each batch (default: 128)
|
batch_size: size of each batch (default: 16)
|
||||||
verbose: if True training progress is printed to console (default: False)
|
verbose: if True training progress is printed to console (default: False)
|
||||||
debug: if True summaries are collected (default: False)
|
debug: if True summaries are collected (default: False)
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user