diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 966177b..71a11cd 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -38,8 +38,11 @@ import os import pickle import time from typing import Dict +from typing import List from typing import Optional +from typing import Union +import numpy as np import tensorflow as tf from tensorflow.python.ops import summary_ops_v2 @@ -168,7 +171,7 @@ def predict(dataset: tf.data.Dataset, def _predict_one_epoch(dataset: tf.data.Dataset, use_dropout: bool, forward_passes_per_image: int, - ssd: tf.keras.Model) -> Dict[str, float]: + ssd: tf.keras.Model) -> Dict[str, Union[float, List[List[np.ndarray]]]]: with summary_ops_v2.always_record_summaries(): epoch_start_time = time.time()