diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 0e0e6fe..0d0f244 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -58,12 +58,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i mnist_valid = [] for i in range(folds): - if i != folding_id: + if i != folding_id: # exclude testing fold, representing 20% of each class with open('data/data_fold_%d.pkl' % i, 'rb') as pkl: fold = pickle.load(pkl) - if len(mnist_valid) == 0: + if len(mnist_valid) == 0: # single out one fold, comprising 20% of each class mnist_valid = fold - else: + else: # form train set from remaining folds, comprising 60% of each class mnist_train += fold outlier_classes = []