Moved dataset preparation into own function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -38,33 +38,27 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
|||||||
GRACE = 10
|
GRACE = 10
|
||||||
|
|
||||||
|
|
||||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||||
iteration: int,
|
fold_prefix: str = 'data/data_fold_',
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
batch_size: int = 128, folds: int = 5) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
|
||||||
folds: int = 5, verbose: bool = True):
|
|
||||||
"""
|
"""
|
||||||
Train AAE for mnist data set.
|
Prepares the MNIST data
|
||||||
|
|
||||||
:param folding_id: id of fold used for test data
|
:param test_fold_id: id of test fold
|
||||||
:param inlier_classes: list of class ids that are considered inliers
|
:param inlier_classes: list of class ids that are considered inliers
|
||||||
:param total_classes: total number of classes
|
:param total_classes: total number of classes
|
||||||
:param iteration: identifier for the current training run
|
:param fold_prefix: the prefix for the fold pickle files (default: 'data/data_fold_')
|
||||||
:param channels: number of channels in input image
|
:param batch_size: size of batch (default: 128)
|
||||||
:param zsize: size of the intermediary z
|
:param folds: number of folds (default: 5)
|
||||||
:param lr: learning rate
|
:return: tuple(train dataset, valid dataset)
|
||||||
:param batch_size: size of each batch
|
|
||||||
:param train_epoch: number of epochs to train
|
|
||||||
:param folds: number of folds available
|
|
||||||
:param verbose: if True prints train progress info to console
|
|
||||||
"""
|
"""
|
||||||
# prepare data
|
# prepare data
|
||||||
mnist_train = []
|
mnist_train = []
|
||||||
mnist_valid = []
|
mnist_valid = []
|
||||||
|
|
||||||
for i in range(folds):
|
for i in range(folds):
|
||||||
if i != folding_id: # exclude testing fold, representing 20% of each class
|
if i != test_fold_id: # exclude testing fold, representing 20% of each class
|
||||||
with open('data/data_fold_%d.pkl' % i, 'rb') as pkl:
|
with open(f"{fold_prefix}{i:d}.pkl", 'rb') as pkl:
|
||||||
fold = pickle.load(pkl)
|
fold = pickle.load(pkl)
|
||||||
if len(mnist_valid) == 0: # single out one fold, comprising 20% of each class
|
if len(mnist_valid) == 0: # single out one fold, comprising 20% of each class
|
||||||
mnist_valid = fold
|
mnist_valid = fold
|
||||||
@ -82,17 +76,43 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
def list_of_pairs_to_numpy(list_of_pairs: Sequence[Tuple[int, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Converts a list of pairs to a numpy array.
|
Converts a list of pairs to a numpy array.
|
||||||
|
|
||||||
:param list_of_pairs: list of pairs
|
:param list_of_pairs: list of pairs
|
||||||
:return: numpy array
|
:return: numpy array
|
||||||
"""
|
"""
|
||||||
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
|
return np.asarray([x[1] for x in list_of_pairs], np.float32), np.asarray([x[0] for x in list_of_pairs], np.int)
|
||||||
|
|
||||||
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
mnist_train_x, mnist_train_y = list_of_pairs_to_numpy(mnist_train)
|
||||||
|
mnist_valid_x, mnist_valid_y = list_of_pairs_to_numpy(mnist_valid)
|
||||||
|
|
||||||
# get dataset
|
# get dataset
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
train_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||||
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize)
|
train_dataset = train_dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size,
|
||||||
|
drop_remainder=True).map(normalize)
|
||||||
|
valid_dataset = tf.data.Dataset.from_tensor_slices((mnist_valid_x, mnist_valid_y))
|
||||||
|
valid_dataset = valid_dataset.shuffle(mnist_valid_x.shape[0]).batch(batch_size,
|
||||||
|
drop_remainder=True).map(normalize)
|
||||||
|
|
||||||
|
return train_dataset, valid_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
||||||
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
|
verbose: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Train AAE for given data set.
|
||||||
|
|
||||||
|
:param dataset: train dataset
|
||||||
|
:param iteration: identifier for the current training run
|
||||||
|
:param result_prefix: prefix for result images
|
||||||
|
:param channels: number of channels in input image (default: 1)
|
||||||
|
:param zsize: size of the intermediary z (default: 32)
|
||||||
|
:param lr: initial learning rate (default: 0.002)
|
||||||
|
:param batch_size: the size of each batch (default: 128)
|
||||||
|
:param train_epoch: number of epochs to train (default: 80)
|
||||||
|
:param verbose: if True prints train progress info to console (default: True)
|
||||||
|
"""
|
||||||
|
|
||||||
# get models
|
# get models
|
||||||
encoder = Encoder(zsize)
|
encoder = Encoder(zsize)
|
||||||
@ -226,7 +246,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||||
step=global_step_decoder)
|
step=global_step_decoder)
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
filename = 'results' + str(inlier_classes[0]) + '/reconstruction_' + str(epoch) + '.png'
|
filename = os.path.join(result_prefix, 'reconstruction_' + str(epoch) + '.png')
|
||||||
ndarr = grid.cpu().numpy()
|
ndarr = grid.cpu().numpy()
|
||||||
im = Image.fromarray(ndarr)
|
im = Image.fromarray(ndarr)
|
||||||
im.save(filename)
|
im.save(filename)
|
||||||
@ -261,7 +281,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
||||||
max_images=1, step=global_step_decoder)
|
max_images=1, step=global_step_decoder)
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
filename = 'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png'
|
filename = os.path.join(result_prefix, 'sample_' + str(epoch) + '.png')
|
||||||
ndarr = grid.cpu().numpy()
|
ndarr = grid.cpu().numpy()
|
||||||
im = Image.fromarray(ndarr)
|
im = Image.fromarray(ndarr)
|
||||||
im.save(filename)
|
im.save(filename)
|
||||||
@ -493,7 +513,9 @@ if __name__ == "__main__":
|
|||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
inlier_classes = [0]
|
inlier_classes = [0]
|
||||||
iteration = 1
|
iteration = 1
|
||||||
|
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=[0],
|
||||||
|
total_classes=10)
|
||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||||
train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration)
|
train(dataset=train_dataset, iteration=iteration, result_prefix='results' + str(inlier_classes[0]) + '/')
|
||||||
|
|||||||
Reference in New Issue
Block a user