Optimised position of saving predictions after observations are grouped

Before the entropy threshold is applied, there is no difference
between the entropy thresholds. Therefore, the predictions should only
be saved once.

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-09-13 13:05:09 +02:00
parent d877da3ef3
commit 75f4e4cb1e

View File

@ -323,11 +323,11 @@ def _predict_loop(generator: Generator,
if use_bayesian: if use_bayesian:
decoded_predictions = callables.decode_func_dropout(predictions) decoded_predictions = callables.decode_func_dropout(predictions)
observations = callables.get_observations_func(decoded_predictions) observations = callables.get_observations_func(decoded_predictions)
callables.save_func(observations, original_labels, filenames,
batch_nr=batch_counter,
suffix="_prediction")
for entropy_threshold in entropy_thresholds: for entropy_threshold in entropy_thresholds:
if use_bayesian: if use_bayesian:
callables.save_func(observations, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_prediction")
decoded_predictions = callables.apply_entropy_threshold_func(observations, decoded_predictions = callables.apply_entropy_threshold_func(observations,
entropy_threshold=entropy_threshold) entropy_threshold=entropy_threshold)
callables.save_func(decoded_predictions, original_labels, filenames, callables.save_func(decoded_predictions, original_labels, filenames,
@ -482,8 +482,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms:
def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]], def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]],
original_labels: np.ndarray, filenames: Sequence[str], original_labels: np.ndarray, filenames: Sequence[str],
output_file: str, label_output_file: str, output_file: str, label_output_file: str,
batch_nr: int, nr_digits: int, entropy_threshold: float, batch_nr: int, nr_digits: int, suffix: str,
suffix: str) -> None: entropy_threshold: Optional[float] = 0) -> None:
counter_str = str(batch_nr).zfill(nr_digits) counter_str = str(batch_nr).zfill(nr_digits)
filename = f"{output_file}{suffix}-{counter_str}" filename = f"{output_file}{suffix}-{counter_str}"