Removed caching of dataset

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 11:37:05 +01:00
parent 85d743c058
commit 5ba3bc552f

View File

@ -87,7 +87,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
# get dataset # get dataset
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y)) dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize).cache() dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize)
# get models # get models
encoder = Encoder(zsize) encoder = Encoder(zsize)