Moved dataset preparation into own function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -38,33 +38,27 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
||||
GRACE = 10
|
||||
|
||||
|
||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||
iteration: int,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
folds: int = 5, verbose: bool = True):
|
||||
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||
fold_prefix: str = 'data/data_fold_',
|
||||
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
||||
"""
|
||||
Train AAE for mnist data set.
|
||||
Prepares the MNIST data
|
||||
|
||||
:param folding_id: id of fold used for test data
|
||||
:param test_fold_id: id of test fold
|
||||
:param inlier_classes: list of class ids that are considered inliers
|
||||
:param total_classes: total number of classes
|
||||
:param iteration: identifier for the current training run
|
||||
:param channels: number of channels in input image
|
||||
:param zsize: size of the intermediary z
|
||||
:param lr: learning rate
|
||||
:param batch_size: size of each batch
|
||||
:param train_epoch: number of epochs to train
|
||||
:param folds: number of folds available
|
||||
:param verbose: if True prints train progress info to console
|
||||
:param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
|
||||
:param batch_size: size of batch (default: 128)
|
||||
:param folds: number of folds (default: 5)
|
||||
:return: tuple(train dataset, valid dataset)
|
||||
"""
|
||||
# prepare data
|
||||
mnist_train = []
|
||||
mnist_valid = []
|
||||
|
||||
for i in range(folds):
|
||||
if i != folding_id: # exclude testing fold, representing 20% of each class
|
||||
with open('data/data_fold_%d.pkl' % i, 'rb') as pkl:
|
||||
if i != test_fold_id: # exclude testing fold, representing 20% of each class
|
||||
with open(f"{fold_prefix}{i:d}.pkl", 'rb') as pkl:
|
||||
fold = pickle.load(pkl)
|
||||
if len(mnist_valid) == 0: # single out one fold, comprising 20% of each class
|
||||
mnist_valid = fold
|
||||
@ -82,17 +76,43 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts a list of pairs to a numpy array.
|
||||
|
||||
|
||||
:param list_of_pairs: list of pairs
|
||||
:return: numpy array
|
||||
"""
|
||||
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
|
||||
|
||||
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
||||
|
||||
mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
|
||||
|
||||
# get dataset
|
||||
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize)
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size,
|
||||
drop_remainder=True).map(normalize)
|
||||
valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y))
|
||||
valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size,
|
||||
drop_remainder=True).map(normalize)
|
||||
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
verbose: bool = True) -> None:
|
||||
"""
|
||||
Train AAE for given data set.
|
||||
|
||||
:param dataset: train dataset
|
||||
:param iteration: identifier for the current training run
|
||||
:param result_prefix: prefix for result images
|
||||
:param channels: number of channels in input image (default: 1)
|
||||
:param zsize: size of the intermediary z (default: 32)
|
||||
:param lr: initial learning rate (default: 0.002)
|
||||
:param batch_size: the size of each batch (default: 128)
|
||||
:param train_epoch: number of epochs to train (default: 80)
|
||||
:param verbose: if True prints train progress info to console (default: True)
|
||||
"""
|
||||
|
||||
# get models
|
||||
encoder = Encoder(zsize)
|
||||
@ -226,7 +246,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||
step=global_step_decoder)
|
||||
from PIL import Image
|
||||
filename = 'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png'
|
||||
filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png')
|
||||
ndarr = grid.cpu().numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filename)
|
||||
@ -261,7 +281,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
||||
max_images=1, step=global_step_decoder)
|
||||
from PIL import Image
|
||||
filename = 'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png'
|
||||
filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png')
|
||||
ndarr = grid.cpu().numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filename)
|
||||
@ -493,7 +513,9 @@ if __name__ == "__main__":
|
||||
tf.enable_eager_execution()
|
||||
inlier_classes = [0]
|
||||
iteration = 1
|
||||
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=[0],
|
||||
total_classes=10)
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration)
|
||||
train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/')
|
||||
|
||||
Reference in New Issue
Block a user