Removed caching of dataset
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -87,7 +87,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
# get dataset
|
||||
dataset = tf.data.Dataset.from_tensor_slices((mnist_train_x, mnist_train_y))
|
||||
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize).cache()
|
||||
dataset = dataset.shuffle(mnist_train_x.shape[0]).batch(batch_size, drop_remainder=True).map(normalize)
|
||||
|
||||
# get models
|
||||
encoder = Encoder(zsize)
|
||||
|
||||
Reference in New Issue
Block a user