From c82f5bf87c7c71d122824fb387f813eaff2cb9c0 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 06:38:27 +0100 Subject: [PATCH] Fixed docstrings Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index ea2288c..fb7d84c 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -41,6 +41,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i folds: int = 5): """ Train AAE for mnist data set. + :param folding_id: id of fold used for test data :param inlier_classes: list of class ids that are considered inliers :param total_classes: total number of classes @@ -76,6 +77,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]: """ Converts a list of pairs to a numpy array. + :param list_of_pairs: list of pairs :return: numpy array """ @@ -129,6 +131,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i def shuffle(train_data: np.ndarray) -> None: """ Shuffles the given training data inplace. + :param train_data: numpy array of training data """ np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data) @@ -388,6 +391,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable: def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable: """ Extracts a batch from data. + :param data: numpy array of data :param it: current iteration in epoch (or batch number) :param batch_size: size of batch