diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 0017891..4c4295f 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -322,17 +322,23 @@ def _predict_loop(generator: Generator, saved_images_prediction = True if use_bayesian: decoded_predictions = callables.decode_func_dropout(predictions) + callables.save_func(decoded_predictions, original_labels, filenames, + batch_nr=batch_counter, + suffix="_prediction", + save_labels=False) observations = callables.get_observations_func(decoded_predictions) callables.save_func(observations, original_labels, filenames, batch_nr=batch_counter, - suffix="_prediction") + suffix="_observation", + save_labels=False) for entropy_threshold in entropy_thresholds: if use_bayesian: decoded_predictions = callables.apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold) callables.save_func(decoded_predictions, original_labels, filenames, batch_nr=batch_counter, entropy_threshold=entropy_threshold, - suffix="_entropy") + suffix="_entropy", + save_labels=False) decoded_predictions = callables.apply_top_k_func(decoded_predictions) else: decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold) @@ -483,6 +489,7 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda original_labels: np.ndarray, filenames: Sequence[str], output_file: str, label_output_file: str, batch_nr: int, nr_digits: int, suffix: str, + save_labels: Optional[bool] = True, entropy_threshold: Optional[float] = 0) -> None: counter_str = str(batch_nr).zfill(nr_digits) @@ -492,7 +499,8 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda with open(filename, "wb") as file, open(label_filename, "wb") as label_file: pickle.dump(transformed_predictions, file) - pickle.dump({"labels": original_labels, "filenames": filenames}, label_file) + if save_labels: + pickle.dump({"labels": original_labels, "filenames": filenames}, label_file) def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,