Added saving of predictions before entropy threshold

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-09-13 12:15:54 +02:00
parent 681d23f345
commit 53022bf1a9

View File

@ -32,6 +32,7 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import Union
import math import math
import numpy as np import numpy as np
@ -321,6 +322,9 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
observations = get_observations_func(decoded_predictions) observations = get_observations_func(decoded_predictions)
for entropy_threshold in entropy_thresholds: for entropy_threshold in entropy_thresholds:
if use_dropout: if use_dropout:
save_func(observations, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_prediction")
decoded_predictions = apply_entropy_threshold_func(observations, decoded_predictions = apply_entropy_threshold_func(observations,
entropy_threshold=entropy_threshold) entropy_threshold=entropy_threshold)
save_func(decoded_predictions, original_labels, filenames, save_func(decoded_predictions, original_labels, filenames,
@ -474,7 +478,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms:
return inverse_transform_func(decoded_predictions, inverse_transforms) return inverse_transform_func(decoded_predictions, inverse_transforms)
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str], def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]],
original_labels: np.ndarray, filenames: Sequence[str],
output_file: str, label_output_file: str, output_file: str, label_output_file: str,
batch_nr: int, nr_digits: int, entropy_threshold: float, batch_nr: int, nr_digits: int, entropy_threshold: float,
suffix: str) -> None: suffix: str) -> None: