Fixed docstrings

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 06:38:27 +01:00
parent 14cb27afd7
commit c82f5bf87c

View File

@ -41,6 +41,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
folds: int = 5):
"""
Train AAE for mnist data set.
:param folding_id: id of fold used for test data
:param inlier_classes: list of class ids that are considered inliers
:param total_classes: total number of classes
@ -76,6 +77,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
"""
Converts a list of pairs to a numpy array.
:param list_of_pairs: list of pairs
:return: numpy array
"""
@ -129,6 +131,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
def shuffle(train_data: np.ndarray) -> None:
"""
Shuffles the given training data inplace.
:param train_data: numpy array of training data
"""
np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
@ -388,6 +391,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
"""
Extracts a batch from data.
:param data: numpy array of data
:param it: current iteration in epoch (or batch number)
:param batch_size: size of batch