Improved readability of code
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
01f43651b5
commit
d877da3ef3
|
@ -37,6 +37,7 @@ from typing import Union
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from attributedict.collections import AttributeDict
|
||||||
|
|
||||||
from twomartens.masterthesis import config
|
from twomartens.masterthesis import config
|
||||||
from twomartens.masterthesis import debug
|
from twomartens.masterthesis import debug
|
||||||
|
@ -116,6 +117,7 @@ def get_model(use_bayesian: bool,
|
||||||
|
|
||||||
|
|
||||||
def get_loss_func() -> callable:
|
def get_loss_func() -> callable:
|
||||||
|
"""Returns loss function for SSD."""
|
||||||
return keras_ssd_loss.SSDLoss().compute_loss
|
return keras_ssd_loss.SSDLoss().compute_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,21 +166,27 @@ def predict(generator: callable,
|
||||||
"""
|
"""
|
||||||
output_file, label_output_file = _predict_prepare_paths(output_path, use_bayesian)
|
output_file, label_output_file = _predict_prepare_paths(output_path, use_bayesian)
|
||||||
|
|
||||||
_predict_loop(generator, use_bayesian, steps_per_epoch,
|
_predict_loop(
|
||||||
dropout_step=functools.partial(
|
generator=generator,
|
||||||
|
use_bayesian=use_bayesian,
|
||||||
|
conf_obj=conf_obj,
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
callables=AttributeDict({
|
||||||
|
"dropout_step": functools.partial(
|
||||||
_predict_dropout_step,
|
_predict_dropout_step,
|
||||||
model=model,
|
model=model,
|
||||||
batch_size=conf_obj.parameters.batch_size,
|
batch_size=conf_obj.parameters.batch_size,
|
||||||
forward_passes_per_image=conf_obj.parameters.ssd_forward_passes_per_image
|
forward_passes_per_image=conf_obj.parameters.ssd_forward_passes_per_image
|
||||||
),
|
),
|
||||||
vanilla_step=functools.partial(_predict_vanilla_step, model=model),
|
"vanilla_step": functools.partial(_predict_vanilla_step, model=model),
|
||||||
save_images=functools.partial(_predict_save_images,
|
"save_images": functools.partial(
|
||||||
|
_predict_save_images,
|
||||||
save_images=debug.save_ssd_train_images,
|
save_images=debug.save_ssd_train_images,
|
||||||
get_coco_cat_maps_func=coco_utils.get_coco_category_maps,
|
get_coco_cat_maps_func=coco_utils.get_coco_category_maps,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
coco_path=conf_obj.paths.coco,
|
conf_obj=conf_obj
|
||||||
image_size=conf_obj.parameters.ssd_image_size),
|
),
|
||||||
decode_func=functools.partial(
|
"decode_func": functools.partial(
|
||||||
_decode_predictions,
|
_decode_predictions,
|
||||||
decode_func=ssd_output_decoder.decode_detections,
|
decode_func=ssd_output_decoder.decode_detections,
|
||||||
image_size=conf_obj.parameters.ssd_image_size,
|
image_size=conf_obj.parameters.ssd_image_size,
|
||||||
|
@ -186,34 +194,35 @@ def predict(generator: callable,
|
||||||
iou_threshold=conf_obj.parameters.ssd_iou_threshold,
|
iou_threshold=conf_obj.parameters.ssd_iou_threshold,
|
||||||
top_k=conf_obj.parameters.ssd_top_k
|
top_k=conf_obj.parameters.ssd_top_k
|
||||||
),
|
),
|
||||||
decode_func_dropout=functools.partial(
|
"decode_func_dropout": functools.partial(
|
||||||
_decode_predictions_dropout,
|
_decode_predictions_dropout,
|
||||||
decode_func=ssd_output_decoder.decode_detections_dropout,
|
decode_func=ssd_output_decoder.decode_detections_dropout,
|
||||||
image_size=conf_obj.parameters.ssd_image_size,
|
image_size=conf_obj.parameters.ssd_image_size,
|
||||||
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
|
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
|
||||||
),
|
),
|
||||||
apply_entropy_threshold_func=functools.partial(
|
"apply_entropy_threshold_func": functools.partial(
|
||||||
_apply_entropy_filtering,
|
_apply_entropy_filtering,
|
||||||
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
|
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
|
||||||
nr_classes=conf_obj.parameters.nr_classes,
|
nr_classes=conf_obj.parameters.nr_classes,
|
||||||
iou_threshold=conf_obj.parameters.ssd_iou_threshold,
|
iou_threshold=conf_obj.parameters.ssd_iou_threshold,
|
||||||
use_nms=conf_obj.parameters.ssd_use_nms
|
use_nms=conf_obj.parameters.ssd_use_nms
|
||||||
),
|
),
|
||||||
apply_top_k_func=functools.partial(
|
"apply_top_k_func": functools.partial(
|
||||||
_apply_top_k,
|
_apply_top_k,
|
||||||
top_k=conf_obj.parameters.ssd_top_k
|
top_k=conf_obj.parameters.ssd_top_k
|
||||||
),
|
),
|
||||||
get_observations_func=_get_observations,
|
"get_observations_func": _get_observations,
|
||||||
transform_func=functools.partial(
|
"transform_func": functools.partial(
|
||||||
_transform_predictions,
|
_transform_predictions,
|
||||||
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms),
|
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms),
|
||||||
save_func=functools.partial(_save_predictions,
|
"save_func": functools.partial(
|
||||||
|
_save_predictions,
|
||||||
output_file=output_file,
|
output_file=output_file,
|
||||||
label_output_file=label_output_file,
|
label_output_file=label_output_file,
|
||||||
nr_digits=nr_digits),
|
nr_digits=nr_digits
|
||||||
use_entropy_threshold=conf_obj.parameters.ssd_use_entropy_threshold,
|
)
|
||||||
entropy_threshold_min=conf_obj.parameters.ssd_entropy_threshold_min,
|
})
|
||||||
entropy_threshold_max=conf_obj.parameters.ssd_entropy_threshold_max)
|
)
|
||||||
|
|
||||||
|
|
||||||
def train(train_generator: callable,
|
def train(train_generator: callable,
|
||||||
|
@ -285,56 +294,56 @@ def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, st
|
||||||
return output_file, label_output_file
|
return output_file, label_output_file
|
||||||
|
|
||||||
|
|
||||||
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
def _predict_loop(generator: Generator,
|
||||||
dropout_step: callable, vanilla_step: callable,
|
use_bayesian: bool,
|
||||||
save_images: callable, decode_func: callable,
|
conf_obj: config.Config,
|
||||||
decode_func_dropout: callable, get_observations_func: callable,
|
steps_per_epoch: int,
|
||||||
apply_entropy_threshold_func: callable, apply_top_k_func: callable,
|
callables: AttributeDict) -> None:
|
||||||
transform_func: callable, save_func: callable,
|
|
||||||
use_entropy_threshold: bool, entropy_threshold_min: float,
|
|
||||||
entropy_threshold_max: float) -> None:
|
|
||||||
|
|
||||||
batch_counter = 0
|
batch_counter = 0
|
||||||
saved_images_prediction = False
|
saved_images_prediction = False
|
||||||
saved_images_decoding = False
|
saved_images_decoding = False
|
||||||
if use_entropy_threshold:
|
if conf_obj.parameters.ssd_use_entropy_threshold:
|
||||||
nr_steps = math.floor((entropy_threshold_max - entropy_threshold_min) * 10)
|
nr_steps = math.floor(
|
||||||
entropy_thresholds = [round(i / 10 + entropy_threshold_min, 1) for i in range(nr_steps)]
|
(conf_obj.parameters.ssd_entropy_threshold_max - conf_obj.parameters.ssd_entropy_threshold_min) * 10
|
||||||
|
)
|
||||||
|
entropy_thresholds = [round(i / 10 + conf_obj.parameters.ssd_entropy_threshold_min, 1) for i in range(nr_steps)]
|
||||||
else:
|
else:
|
||||||
entropy_thresholds = [0]
|
entropy_thresholds = [0]
|
||||||
|
|
||||||
for inputs, filenames, inverse_transforms, original_labels in generator:
|
for inputs, filenames, inverse_transforms, original_labels in generator:
|
||||||
if use_dropout:
|
if use_bayesian:
|
||||||
predictions = dropout_step(inputs)
|
predictions = callables.dropout_step(inputs)
|
||||||
else:
|
else:
|
||||||
predictions = vanilla_step(inputs)
|
predictions = callables.vanilla_step(inputs)
|
||||||
|
|
||||||
if not saved_images_prediction:
|
if not saved_images_prediction:
|
||||||
save_images(inputs, predictions, custom_string="after-prediction")
|
callables.save_images(inputs, predictions, custom_string="after-prediction")
|
||||||
saved_images_prediction = True
|
saved_images_prediction = True
|
||||||
if use_dropout:
|
if use_bayesian:
|
||||||
decoded_predictions = decode_func_dropout(predictions)
|
decoded_predictions = callables.decode_func_dropout(predictions)
|
||||||
observations = get_observations_func(decoded_predictions)
|
observations = callables.get_observations_func(decoded_predictions)
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_dropout:
|
if use_bayesian:
|
||||||
save_func(observations, original_labels, filenames,
|
callables.save_func(observations, original_labels, filenames,
|
||||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
||||||
suffix="_prediction")
|
suffix="_prediction")
|
||||||
decoded_predictions = apply_entropy_threshold_func(observations,
|
decoded_predictions = callables.apply_entropy_threshold_func(observations,
|
||||||
entropy_threshold=entropy_threshold)
|
entropy_threshold=entropy_threshold)
|
||||||
save_func(decoded_predictions, original_labels, filenames,
|
callables.save_func(decoded_predictions, original_labels, filenames,
|
||||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
||||||
suffix="_entropy")
|
suffix="_entropy")
|
||||||
decoded_predictions = apply_top_k_func(decoded_predictions)
|
decoded_predictions = callables.apply_top_k_func(decoded_predictions)
|
||||||
else:
|
else:
|
||||||
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold)
|
||||||
if not saved_images_decoding:
|
if not saved_images_decoding:
|
||||||
custom_string = f"after-decoding-{entropy_threshold}" if use_entropy_threshold else "after-decoding"
|
custom_string = f"after-decoding-{entropy_threshold}" \
|
||||||
save_images(inputs, decoded_predictions, custom_string=custom_string)
|
if conf_obj.parameters.ssd_use_entropy_threshold else "after-decoding"
|
||||||
|
callables.save_images(inputs, decoded_predictions, custom_string=custom_string)
|
||||||
|
|
||||||
transformed_predictions = transform_func(decoded_predictions,
|
transformed_predictions = callables.transform_func(decoded_predictions,
|
||||||
inverse_transforms)
|
inverse_transforms)
|
||||||
save_func(transformed_predictions, original_labels, filenames,
|
callables.save_func(transformed_predictions, original_labels, filenames,
|
||||||
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
batch_nr=batch_counter, entropy_threshold=entropy_threshold,
|
||||||
suffix="_transformed")
|
suffix="_transformed")
|
||||||
|
|
||||||
|
@ -387,10 +396,7 @@ def _decode_predictions(predictions: np.ndarray,
|
||||||
def _decode_predictions_dropout(predictions: np.ndarray,
|
def _decode_predictions_dropout(predictions: np.ndarray,
|
||||||
decode_func: callable,
|
decode_func: callable,
|
||||||
image_size: int,
|
image_size: int,
|
||||||
# entropy_threshold: float,
|
|
||||||
confidence_threshold: float,
|
confidence_threshold: float,
|
||||||
# iou_threshold: float,
|
|
||||||
# top_k: int
|
|
||||||
) -> List[np.ndarray]:
|
) -> List[np.ndarray]:
|
||||||
return decode_func(
|
return decode_func(
|
||||||
y_pred=predictions,
|
y_pred=predictions,
|
||||||
|
@ -492,10 +498,11 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
|
||||||
def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
|
def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
|
||||||
save_images: callable,
|
save_images: callable,
|
||||||
get_coco_cat_maps_func: callable,
|
get_coco_cat_maps_func: callable,
|
||||||
output_path: str, coco_path: str,
|
output_path: str,
|
||||||
image_size: int, custom_string: str) -> None:
|
conf_obj: config.Config,
|
||||||
|
custom_string: str) -> None:
|
||||||
save_images(inputs, predictions,
|
save_images(inputs, predictions,
|
||||||
output_path, coco_path, image_size,
|
output_path, conf_obj.paths.coco, conf_obj.parameters.ssd_image_size,
|
||||||
get_coco_cat_maps_func, custom_string)
|
get_coco_cat_maps_func, custom_string)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue