Optimised position of saving predictions after observations are grouped
Before the entropy threshold is applied, there is no difference between the entropy thresholds. Therefore, the predictions should only be saved once. Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -323,11 +323,11 @@ def _predict_loop(generator: Generator,
|
|||||||
if use_bayesian:
|
if use_bayesian:
|
||||||
decoded_predictions = callables.decode_func_dropout(predictions)
|
decoded_predictions = callables.decode_func_dropout(predictions)
|
||||||
observations = callables.get_observations_func(decoded_predictions)
|
observations = callables.get_observations_func(decoded_predictions)
|
||||||
|
callables.save_func(observations, original_labels, filenames,
|
||||||
|
batch_nr=batch_counter,
|
||||||
|
suffix="_prediction")
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_bayesian:
|
if use_bayesian:
|
||||||
callables.save_func(observations, original_labels, filenames,
|
|
||||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
|
||||||
suffix="_prediction")
|
|
||||||
decoded_predictions = callables.apply_entropy_threshold_func(observations,
|
decoded_predictions = callables.apply_entropy_threshold_func(observations,
|
||||||
entropy_threshold=entropy_threshold)
|
entropy_threshold=entropy_threshold)
|
||||||
callables.save_func(decoded_predictions, original_labels, filenames,
|
callables.save_func(decoded_predictions, original_labels, filenames,
|
||||||
@ -482,8 +482,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms:
|
|||||||
def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]],
|
def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]],
|
||||||
original_labels: np.ndarray, filenames: Sequence[str],
|
original_labels: np.ndarray, filenames: Sequence[str],
|
||||||
output_file: str, label_output_file: str,
|
output_file: str, label_output_file: str,
|
||||||
batch_nr: int, nr_digits: int, entropy_threshold: float,
|
batch_nr: int, nr_digits: int, suffix: str,
|
||||||
suffix: str) -> None:
|
entropy_threshold: Optional[float] = 0) -> None:
|
||||||
|
|
||||||
counter_str = str(batch_nr).zfill(nr_digits)
|
counter_str = str(batch_nr).zfill(nr_digits)
|
||||||
filename = f"{output_file}{suffix}-{counter_str}"
|
filename = f"{output_file}{suffix}-{counter_str}"
|
||||||
|
|||||||
Reference in New Issue
Block a user