From 53022bf1a9a5c772235b77c3e66224dc43ef415d Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 13 Sep 2019 12:15:54 +0200 Subject: [PATCH] Added saving of predictions before entropy threshold Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 0f3c7f8..46f8d97 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -32,6 +32,7 @@ from typing import List from typing import Optional from typing import Sequence from typing import Tuple +from typing import Union import math import numpy as np @@ -321,6 +322,9 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, observations = get_observations_func(decoded_predictions) for entropy_threshold in entropy_thresholds: if use_dropout: + save_func(observations, original_labels, filenames, + batch_nr=batch_counter, entropy_threshold=entropy_threshold, + suffix="_prediction") decoded_predictions = apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold) save_func(decoded_predictions, original_labels, filenames, @@ -474,7 +478,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: return inverse_transform_func(decoded_predictions, inverse_transforms) -def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str], +def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]], + original_labels: np.ndarray, filenames: Sequence[str], output_file: str, label_output_file: str, batch_nr: int, nr_digits: int, entropy_threshold: float, suffix: str) -> None: