diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index ef203e4..1d2e7b3 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -351,7 +351,8 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, transformed_predictions = transform_func(decoded_predictions, inverse_transforms) save_func(transformed_predictions, original_labels, filenames, - batch_nr=batch_counter, entropy_threshold=entropy_threshold) + batch_nr=batch_counter, entropy_threshold=entropy_threshold, + suffix="_transformed") if not saved_images_decoding: saved_images_decoding = True @@ -486,10 +487,11 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str], output_file: str, label_output_file: str, - batch_nr: int, nr_digits: int, entropy_threshold: float) -> None: + batch_nr: int, nr_digits: int, entropy_threshold: float, + suffix: str) -> None: counter_str = str(batch_nr).zfill(nr_digits) - filename = f"{output_file}-{counter_str}" + filename = f"{output_file}{suffix}-{counter_str}" filename = f"{filename}-{entropy_threshold}" if entropy_threshold else filename label_filename = f"{label_output_file}-{counter_str}.bin"