Added TODOs to train module
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -30,6 +30,10 @@ Functions:
|
|||||||
prepare_training_data(...): prepares the mnist training data
|
prepare_training_data(...): prepares the mnist training data
|
||||||
train(...): trains the AAE models
|
train(...): trains the AAE models
|
||||||
|
|
||||||
|
Todos:
|
||||||
|
- fix early stopping
|
||||||
|
- fix losses reaching exactly zero
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
@ -621,7 +625,7 @@ def _normalize(feature: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tens
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
inlier_classes = [0]
|
inlier_classes = [3]
|
||||||
iteration = 1
|
iteration = 1
|
||||||
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes,
|
train_dataset, _ = prepare_training_data(test_fold_id=0, inlier_classes=inlier_classes,
|
||||||
total_classes=10)
|
total_classes=10)
|
||||||
|
|||||||
Reference in New Issue
Block a user