diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index af63011..078e498 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -38,33 +38,27 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy GRACE = 10 -def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int, - iteration: int, - channels: int = 1, zsize: int = 32, lr: float = 0.002, - batch_size: int = 128, train_epoch: int = 80, - folds: int = 5, verbose: bool = True): +def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int, + fold_prefix: str = 'data/data_fold_', + batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]: """ - Train AAE for mnist data set. + Prepares the MNIST data - :param folding_id: id of fold used for test data + :param test_fold_id: id of test fold :param inlier_classes: list of class ids that are considered inliers :param total_classes: total number of classes - :param iteration: identifier for the current training run - :param channels: number of channels in input image - :param zsize: size of the intermediary z - :param lr: learning rate - :param batch_size: size of each batch - :param train_epoch: number of epochs to train - :param folds: number of folds available - :param verbose: if True prints train progress info to console + :param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_') + :param batch_size: size of batch (default: 128) + :param folds: number of folds (default: 5) + :return: tuple(train dataset, valid dataset) """ # prepare data mnist_train = [] mnist_valid = [] for i in range(folds): - if i != folding_id: # exclude testing fold, representing 20% of each class - with open('data/data_fold_%d.pkl' % i, 'rb') as pkl: + if i != test_fold_id: # exclude testing fold, representing 20% of each class + with open(f"{fold_prefix}{i:d}.pkl", 'rb') as pkl: fold = pickle.load(pkl) if len(mnist_valid) == 0: # single out one fold, comprising 20% of each class mnist_valid = fold @@ -82,17 +76,43 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]: """ Converts a list of pairs to a numpy array. - + :param list_of_pairs: list of pairs :return: numpy array """ return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int) mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train) - + mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid) + # get dataset - dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y)) - dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize) + train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y)) + train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, + drop_remainder=True).map(normalize) + valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y)) + valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size, + drop_remainder=True).map(normalize) + + return train_dataset, valid_dataset + + +def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, + channels: int = 1, zsize: int = 32, lr: float = 0.002, + batch_size: int = 128, train_epoch: int = 80, + verbose: bool = True) -> None: + """ + Train AAE for given data set. + + :param dataset: train dataset + :param iteration: identifier for the current training run + :param result_prefix: prefix for result images + :param channels: number of channels in input image (default: 1) + :param zsize: size of the intermediary z (default: 32) + :param lr: initial learning rate (default: 0.002) + :param batch_size: the size of each batch (default: 128) + :param train_epoch: number of epochs to train (default: 80) + :param verbose: if True prints train progress info to console (default: True) + """ # get models encoder = Encoder(zsize) @@ -226,7 +246,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i tensor=k.expand_dims(grid, axis=0), max_images=1, step=global_step_decoder) from PIL import Image - filename = 'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png' + filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png') ndarr = grid.cpu().numpy() im = Image.fromarray(ndarr) im.save(filename) @@ -261,7 +281,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), max_images=1, step=global_step_decoder) from PIL import Image - filename = 'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png' + filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png') ndarr = grid.cpu().numpy() im = Image.fromarray(ndarr) im.save(filename) @@ -493,7 +513,9 @@ if __name__ == "__main__": tf.enable_eager_execution() inlier_classes = [0] iteration = 1 + train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=[0], + total_classes=10) train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): - train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration) + train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/')