Save predictions directly before observations are grouped
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -322,17 +322,23 @@ def _predict_loop(generator: Generator,
|
|||||||
saved_images_prediction = True
|
saved_images_prediction = True
|
||||||
if use_bayesian:
|
if use_bayesian:
|
||||||
decoded_predictions = callables.decode_func_dropout(predictions)
|
decoded_predictions = callables.decode_func_dropout(predictions)
|
||||||
|
callables.save_func(decoded_predictions, original_labels, filenames,
|
||||||
|
batch_nr=batch_counter,
|
||||||
|
suffix="_prediction",
|
||||||
|
save_labels=False)
|
||||||
observations = callables.get_observations_func(decoded_predictions)
|
observations = callables.get_observations_func(decoded_predictions)
|
||||||
callables.save_func(observations, original_labels, filenames,
|
callables.save_func(observations, original_labels, filenames,
|
||||||
batch_nr=batch_counter,
|
batch_nr=batch_counter,
|
||||||
suffix="_prediction")
|
suffix="_observation",
|
||||||
|
save_labels=False)
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_bayesian:
|
if use_bayesian:
|
||||||
decoded_predictions = callables.apply_entropy_threshold_func(observations,
|
decoded_predictions = callables.apply_entropy_threshold_func(observations,
|
||||||
entropy_threshold=entropy_threshold)
|
entropy_threshold=entropy_threshold)
|
||||||
callables.save_func(decoded_predictions, original_labels, filenames,
|
callables.save_func(decoded_predictions, original_labels, filenames,
|
||||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
||||||
suffix="_entropy")
|
suffix="_entropy",
|
||||||
|
save_labels=False)
|
||||||
decoded_predictions = callables.apply_top_k_func(decoded_predictions)
|
decoded_predictions = callables.apply_top_k_func(decoded_predictions)
|
||||||
else:
|
else:
|
||||||
decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold)
|
decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold)
|
||||||
@ -483,6 +489,7 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
|
|||||||
original_labels: np.ndarray, filenames: Sequence[str],
|
original_labels: np.ndarray, filenames: Sequence[str],
|
||||||
output_file: str, label_output_file: str,
|
output_file: str, label_output_file: str,
|
||||||
batch_nr: int, nr_digits: int, suffix: str,
|
batch_nr: int, nr_digits: int, suffix: str,
|
||||||
|
save_labels: Optional[bool] = True,
|
||||||
entropy_threshold: Optional[float] = 0) -> None:
|
entropy_threshold: Optional[float] = 0) -> None:
|
||||||
|
|
||||||
counter_str = str(batch_nr).zfill(nr_digits)
|
counter_str = str(batch_nr).zfill(nr_digits)
|
||||||
@ -492,7 +499,8 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
|
|||||||
|
|
||||||
with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
|
with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
|
||||||
pickle.dump(transformed_predictions, file)
|
pickle.dump(transformed_predictions, file)
|
||||||
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
|
if save_labels:
|
||||||
|
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
|
||||||
|
|
||||||
|
|
||||||
def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
|
def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
|
||||||
|
|||||||
Reference in New Issue
Block a user