diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 64b3da1..39acc75 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -400,4 +400,6 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable: if __name__ == "__main__": tf.enable_eager_execution() - train_mnist(folding_id=0, inlier_classes=[0], total_classes=10) + train_summary_writer = summary_ops_v2.create_file_writer('./summaries/train') + with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): + train_mnist(folding_id=0, inlier_classes=[0], total_classes=10)