Added saving of predictions before entropy threshold
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -32,6 +32,7 @@ from typing import List
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -321,6 +322,9 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
|||||||
observations = get_observations_func(decoded_predictions)
|
observations = get_observations_func(decoded_predictions)
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
|
save_func(observations, original_labels, filenames,
|
||||||
|
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
||||||
|
suffix="_prediction")
|
||||||
decoded_predictions = apply_entropy_threshold_func(observations,
|
decoded_predictions = apply_entropy_threshold_func(observations,
|
||||||
entropy_threshold=entropy_threshold)
|
entropy_threshold=entropy_threshold)
|
||||||
save_func(decoded_predictions, original_labels, filenames,
|
save_func(decoded_predictions, original_labels, filenames,
|
||||||
@ -474,7 +478,8 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms:
|
|||||||
return inverse_transform_func(decoded_predictions, inverse_transforms)
|
return inverse_transform_func(decoded_predictions, inverse_transforms)
|
||||||
|
|
||||||
|
|
||||||
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
|
def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.ndarray]],
|
||||||
|
original_labels: np.ndarray, filenames: Sequence[str],
|
||||||
output_file: str, label_output_file: str,
|
output_file: str, label_output_file: str,
|
||||||
batch_nr: int, nr_digits: int, entropy_threshold: float,
|
batch_nr: int, nr_digits: int, entropy_threshold: float,
|
||||||
suffix: str) -> None:
|
suffix: str) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user