@ -41,6 +41,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
folds: int = 5):
|
||||
"""
|
||||
Train AAE for mnist data set.
|
||||
|
||||
:param folding_id: id of fold used for test data
|
||||
:param inlier_classes: list of class ids that are considered inliers
|
||||
:param total_classes: total number of classes
|
||||
@ -76,6 +77,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts a list of pairs to a numpy array.
|
||||
|
||||
:param list_of_pairs: list of pairs
|
||||
:return: numpy array
|
||||
"""
|
||||
@ -129,6 +131,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
def shuffle(train_data: np.ndarray) -> None:
|
||||
"""
|
||||
Shuffles the given training data inplace.
|
||||
|
||||
:param train_data: numpy array of training data
|
||||
"""
|
||||
np.take(train_data, np.random.permutation(train_data.shape[0]), axis=0, out=train_data)
|
||||
@ -388,6 +391,7 @@ def get_z_variable(batch_size: int, zsize: int) -> tf.Variable:
|
||||
def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tf.Variable:
|
||||
"""
|
||||
Extracts a batch from data.
|
||||
|
||||
:param data: numpy array of data
|
||||
:param it: current iteration in epoch (or batch number)
|
||||
:param batch_size: size of batch
|
||||
|
||||
Reference in New Issue
Block a user