Save predictions directly before observations are grouped

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-09-14 13:43:28 +02:00
parent 6a990b204a
commit 8230488031

View File

@ -322,17 +322,23 @@ def _predict_loop(generator: Generator,
saved_images_prediction = True saved_images_prediction = True
if use_bayesian: if use_bayesian:
decoded_predictions = callables.decode_func_dropout(predictions) decoded_predictions = callables.decode_func_dropout(predictions)
callables.save_func(decoded_predictions, original_labels, filenames,
batch_nr=batch_counter,
suffix="_prediction",
save_labels=False)
observations = callables.get_observations_func(decoded_predictions) observations = callables.get_observations_func(decoded_predictions)
callables.save_func(observations, original_labels, filenames, callables.save_func(observations, original_labels, filenames,
batch_nr=batch_counter, batch_nr=batch_counter,
suffix="_prediction") suffix="_observation",
save_labels=False)
for entropy_threshold in entropy_thresholds: for entropy_threshold in entropy_thresholds:
if use_bayesian: if use_bayesian:
decoded_predictions = callables.apply_entropy_threshold_func(observations, decoded_predictions = callables.apply_entropy_threshold_func(observations,
entropy_threshold=entropy_threshold) entropy_threshold=entropy_threshold)
callables.save_func(decoded_predictions, original_labels, filenames, callables.save_func(decoded_predictions, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold, batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_entropy") suffix="_entropy",
save_labels=False)
decoded_predictions = callables.apply_top_k_func(decoded_predictions) decoded_predictions = callables.apply_top_k_func(decoded_predictions)
else: else:
decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold) decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold)
@ -483,6 +489,7 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
original_labels: np.ndarray, filenames: Sequence[str], original_labels: np.ndarray, filenames: Sequence[str],
output_file: str, label_output_file: str, output_file: str, label_output_file: str,
batch_nr: int, nr_digits: int, suffix: str, batch_nr: int, nr_digits: int, suffix: str,
save_labels: Optional[bool] = True,
entropy_threshold: Optional[float] = 0) -> None: entropy_threshold: Optional[float] = 0) -> None:
counter_str = str(batch_nr).zfill(nr_digits) counter_str = str(batch_nr).zfill(nr_digits)
@ -492,6 +499,7 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
with open(filename, "wb") as file, open(label_filename, "wb") as label_file: with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
pickle.dump(transformed_predictions, file) pickle.dump(transformed_predictions, file)
if save_labels:
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file) pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)