diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 1c935a9..0017891 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -323,11 +323,11 @@ def _predict_loop(generator: Generator, if use_bayesian: decoded_predictions = callables.decode_func_dropout(predictions) observations = callables.get_observations_func(decoded_predictions) + callables.save_func(observations, original_labels, filenames, + batch_nr=batch_counter, + suffix="_prediction") for entropy_threshold in entropy_thresholds: if use_bayesian: - callables.save_func(observations, original_labels, filenames, - batch_nr=batch_counter, entropy_threshold=entropy_threshold, - suffix="_prediction") decoded_predictions = callables.apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold) callables.save_func(decoded_predictions, original_labels, filenames, @@ -482,8 +482,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]], original_labels: np.ndarray, filenames: Sequence[str], output_file: str, label_output_file: str, - batch_nr: int, nr_digits: int, entropy_threshold: float, - suffix: str) -> None: + batch_nr: int, nr_digits: int, suffix: str, + entropy_threshold: Optional[float] = 0) -> None: counter_str = str(batch_nr).zfill(nr_digits) filename = f"{output_file}{suffix}-{counter_str}"