Implemented support for predicting with range of entropy thresholds
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
f270e0add1
commit
28a0d35d36
|
@ -264,7 +264,8 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||
|
||||
batch_size, image_size, learning_rate, \
|
||||
forward_passes_per_image, nr_classes, iou_threshold, dropout_rate, \
|
||||
entropy_threshold, top_k, nr_trajectories, \
|
||||
use_entropy_threshold, entropy_threshold_min, entropy_threshold_max, \
|
||||
top_k, nr_trajectories, \
|
||||
coco_path, output_path, weights_path, ground_truth_path = _ssd_test_get_config_values(args, conf.get_property)
|
||||
|
||||
use_dropout = _ssd_is_dropout(args)
|
||||
|
@ -305,7 +306,9 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||
image_size,
|
||||
batch_size,
|
||||
forward_passes_per_image,
|
||||
entropy_threshold,
|
||||
use_entropy_threshold,
|
||||
entropy_threshold_min,
|
||||
entropy_threshold_max,
|
||||
output_path,
|
||||
coco_path,
|
||||
use_dropout,
|
||||
|
@ -502,7 +505,9 @@ def _ssd_train_get_config_values(config_get: Callable[[str], Union[str, float, i
|
|||
|
||||
def _ssd_test_get_config_values(args: argparse.Namespace,
|
||||
config_get: Callable[[str], Union[str, float, int, bool]]
|
||||
) -> Tuple[int, int, float, int, int, float, float, float, int, int,
|
||||
) -> Tuple[int, int, float, int, int, float, float,
|
||||
bool, float, float,
|
||||
int, int,
|
||||
str, str, str, str]:
|
||||
|
||||
batch_size = config_get("Parameters.batch_size")
|
||||
|
@ -512,7 +517,9 @@ def _ssd_test_get_config_values(args: argparse.Namespace,
|
|||
nr_classes = config_get("Parameters.nr_classes")
|
||||
iou_threshold = config_get("Parameters.ssd_iou_threshold")
|
||||
dropout_rate = config_get("Parameters.ssd_dropout_rate")
|
||||
entropy_threshold = config_get("Parameters.ssd_entropy_threshold")
|
||||
use_entropy_threshold = config_get("Parameters.ssd_use_entropy_threshold")
|
||||
entropy_threshold_min = config_get("Parameters.ssd_entropy_threshold_min")
|
||||
entropy_threshold_max = config_get("Parameters.ssd_entropy_threshold_max")
|
||||
top_k = config_get("Parameters.ssd_top_k")
|
||||
nr_trajectories = config_get("Parameters.nr_trajectories")
|
||||
|
||||
|
@ -532,7 +539,11 @@ def _ssd_test_get_config_values(args: argparse.Namespace,
|
|||
nr_classes,
|
||||
iou_threshold,
|
||||
dropout_rate,
|
||||
entropy_threshold,
|
||||
#
|
||||
use_entropy_threshold,
|
||||
entropy_threshold_min,
|
||||
entropy_threshold_max,
|
||||
#
|
||||
top_k,
|
||||
nr_trajectories,
|
||||
#
|
||||
|
|
|
@ -59,7 +59,9 @@ _CONFIG_PROPS = {
|
|||
"ssd_iou_threshold": (float, "0.45"),
|
||||
"ssd_top_k": (int, "200"),
|
||||
"ssd_dropout_rate": (float, "0.5"),
|
||||
"ssd_entropy_threshold": (float, "2.5"),
|
||||
"ssd_use_entropy_threshold": (bool, False),
|
||||
"ssd_entropy_threshold_min": (float, "0.1"),
|
||||
"ssd_entropy_threshold_max": (float, "2.5"),
|
||||
"nr_trajectories": (int, "-1")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import pickle
|
|||
from typing import List, Sequence, Tuple, Generator
|
||||
from typing import Optional
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -145,7 +146,9 @@ def predict(generator: callable,
|
|||
image_size: int,
|
||||
batch_size: int,
|
||||
forward_passes_per_image: int,
|
||||
entropy_threshold: float,
|
||||
use_entropy_threshold: bool,
|
||||
entropy_threshold_min: float,
|
||||
entropy_threshold_max: float,
|
||||
output_path: str,
|
||||
coco_path: str,
|
||||
use_dropout: bool,
|
||||
|
@ -163,7 +166,9 @@ def predict(generator: callable,
|
|||
batch_size: number of items in every batch
|
||||
forward_passes_per_image: specifies number of forward passes per image
|
||||
used by DropoutSSD
|
||||
entropy_threshold: specifies the threshold for the entropy
|
||||
use_entropy_threshold: if True entropy thresholding is applied
|
||||
entropy_threshold_min: specifies the minimum threshold for the entropy
|
||||
entropy_threshold_max: specifies the maximum threshold for the entropy
|
||||
output_path: the path in which the results should be saved
|
||||
coco_path: the path to the COCO data set
|
||||
use_dropout: if True, multiple forward passes and observations will be used
|
||||
|
@ -188,7 +193,6 @@ def predict(generator: callable,
|
|||
_decode_predictions,
|
||||
decode_func=ssd_output_decoder.decode_detections_fast,
|
||||
image_size=image_size,
|
||||
entropy_threshold=entropy_threshold
|
||||
),
|
||||
transform_func=functools.partial(
|
||||
_transform_predictions,
|
||||
|
@ -196,7 +200,10 @@ def predict(generator: callable,
|
|||
save_func=functools.partial(_save_predictions,
|
||||
output_file=output_file,
|
||||
label_output_file=label_output_file,
|
||||
nr_digits=nr_digits))
|
||||
nr_digits=nr_digits),
|
||||
use_entropy_threshold=use_entropy_threshold,
|
||||
entropy_threshold_min=entropy_threshold_min,
|
||||
entropy_threshold_max=entropy_threshold_max)
|
||||
|
||||
|
||||
def train(train_generator: callable,
|
||||
|
@ -271,26 +278,40 @@ def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, st
|
|||
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||
dropout_step: callable, vanilla_step: callable,
|
||||
save_images: callable, decode_func: callable,
|
||||
transform_func: callable, save_func: callable) -> None:
|
||||
transform_func: callable, save_func: callable,
|
||||
use_entropy_threshold: bool, entropy_threshold_min: float,
|
||||
entropy_threshold_max: float) -> None:
|
||||
|
||||
batch_counter = 0
|
||||
saved_images = False
|
||||
saved_images_prediction = False
|
||||
saved_images_decoding = False
|
||||
if use_entropy_threshold:
|
||||
nr_steps = math.floor((entropy_threshold_max - entropy_threshold_min) * 10)
|
||||
entropy_thresholds = [i / 10 + entropy_threshold_min for i in range(nr_steps)]
|
||||
else:
|
||||
entropy_thresholds = [0]
|
||||
|
||||
for inputs, filenames, inverse_transforms, original_labels in generator:
|
||||
if use_dropout:
|
||||
predictions = dropout_step(inputs)
|
||||
else:
|
||||
predictions = vanilla_step(inputs)
|
||||
|
||||
if not saved_images:
|
||||
if not saved_images_prediction:
|
||||
save_images(inputs, predictions, custom_string="after-prediction")
|
||||
decoded_predictions = decode_func(predictions)
|
||||
if not saved_images:
|
||||
save_images(inputs, decoded_predictions, custom_string="after-decoding")
|
||||
saved_images = True
|
||||
transformed_predictions = transform_func(decoded_predictions,
|
||||
inverse_transforms)
|
||||
save_func(transformed_predictions, original_labels, filenames,
|
||||
batch_nr=batch_counter)
|
||||
for entropy_threshold in entropy_thresholds:
|
||||
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
||||
if not saved_images_decoding:
|
||||
custom_string = f"after-decoding-{entropy_threshold}" if use_entropy_threshold else "after-decoding"
|
||||
save_images(inputs, decoded_predictions, custom_string=custom_string)
|
||||
|
||||
transformed_predictions = transform_func(decoded_predictions,
|
||||
inverse_transforms)
|
||||
save_func(transformed_predictions, original_labels, filenames,
|
||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold)
|
||||
|
||||
if not saved_images_decoding:
|
||||
saved_images_decoding = True
|
||||
|
||||
batch_counter += 1
|
||||
|
||||
|
@ -341,10 +362,11 @@ def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms:
|
|||
|
||||
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
|
||||
output_file: str, label_output_file: str,
|
||||
batch_nr: int, nr_digits: int) -> None:
|
||||
batch_nr: int, nr_digits: int, entropy_threshold: float) -> None:
|
||||
|
||||
counter_str = str(batch_nr).zfill(nr_digits)
|
||||
filename = f"{output_file}-{counter_str}.bin"
|
||||
filename = f"{output_file}-{counter_str}"
|
||||
filename = f"{filename}-{entropy_threshold}" if entropy_threshold else filename
|
||||
label_filename = f"{label_output_file}-{counter_str}.bin"
|
||||
|
||||
with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
|
||||
|
|
Loading…
Reference in New Issue