Improved readability of code

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
Jim Martens 2019-09-13 13:00:14 +02:00
parent 01f43651b5
commit d877da3ef3
1 changed files with 99 additions and 92 deletions

View File

@ -37,6 +37,7 @@ from typing import Union
import math import math
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from attributedict.collections import AttributeDict
from twomartens.masterthesis import config from twomartens.masterthesis import config
from twomartens.masterthesis import debug from twomartens.masterthesis import debug
@ -116,6 +117,7 @@ def get_model(use_bayesian: bool,
def get_loss_func() -> callable: def get_loss_func() -> callable:
"""Returns loss function for SSD."""
return keras_ssd_loss.SSDLoss().compute_loss return keras_ssd_loss.SSDLoss().compute_loss
@ -164,21 +166,27 @@ def predict(generator: callable,
""" """
output_file, label_output_file = _predict_prepare_paths(output_path, use_bayesian) output_file, label_output_file = _predict_prepare_paths(output_path, use_bayesian)
_predict_loop(generator, use_bayesian, steps_per_epoch, _predict_loop(
dropout_step=functools.partial( generator=generator,
use_bayesian=use_bayesian,
conf_obj=conf_obj,
steps_per_epoch=steps_per_epoch,
callables=AttributeDict({
"dropout_step": functools.partial(
_predict_dropout_step, _predict_dropout_step,
model=model, model=model,
batch_size=conf_obj.parameters.batch_size, batch_size=conf_obj.parameters.batch_size,
forward_passes_per_image=conf_obj.parameters.ssd_forward_passes_per_image forward_passes_per_image=conf_obj.parameters.ssd_forward_passes_per_image
), ),
vanilla_step=functools.partial(_predict_vanilla_step, model=model), "vanilla_step": functools.partial(_predict_vanilla_step, model=model),
save_images=functools.partial(_predict_save_images, "save_images": functools.partial(
_predict_save_images,
save_images=debug.save_ssd_train_images, save_images=debug.save_ssd_train_images,
get_coco_cat_maps_func=coco_utils.get_coco_category_maps, get_coco_cat_maps_func=coco_utils.get_coco_category_maps,
output_path=output_path, output_path=output_path,
coco_path=conf_obj.paths.coco, conf_obj=conf_obj
image_size=conf_obj.parameters.ssd_image_size), ),
decode_func=functools.partial( "decode_func": functools.partial(
_decode_predictions, _decode_predictions,
decode_func=ssd_output_decoder.decode_detections, decode_func=ssd_output_decoder.decode_detections,
image_size=conf_obj.parameters.ssd_image_size, image_size=conf_obj.parameters.ssd_image_size,
@ -186,34 +194,35 @@ def predict(generator: callable,
iou_threshold=conf_obj.parameters.ssd_iou_threshold, iou_threshold=conf_obj.parameters.ssd_iou_threshold,
top_k=conf_obj.parameters.ssd_top_k top_k=conf_obj.parameters.ssd_top_k
), ),
decode_func_dropout=functools.partial( "decode_func_dropout": functools.partial(
_decode_predictions_dropout, _decode_predictions_dropout,
decode_func=ssd_output_decoder.decode_detections_dropout, decode_func=ssd_output_decoder.decode_detections_dropout,
image_size=conf_obj.parameters.ssd_image_size, image_size=conf_obj.parameters.ssd_image_size,
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold, confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
), ),
apply_entropy_threshold_func=functools.partial( "apply_entropy_threshold_func": functools.partial(
_apply_entropy_filtering, _apply_entropy_filtering,
confidence_threshold=conf_obj.parameters.ssd_confidence_threshold, confidence_threshold=conf_obj.parameters.ssd_confidence_threshold,
nr_classes=conf_obj.parameters.nr_classes, nr_classes=conf_obj.parameters.nr_classes,
iou_threshold=conf_obj.parameters.ssd_iou_threshold, iou_threshold=conf_obj.parameters.ssd_iou_threshold,
use_nms=conf_obj.parameters.ssd_use_nms use_nms=conf_obj.parameters.ssd_use_nms
), ),
apply_top_k_func=functools.partial( "apply_top_k_func": functools.partial(
_apply_top_k, _apply_top_k,
top_k=conf_obj.parameters.ssd_top_k top_k=conf_obj.parameters.ssd_top_k
), ),
get_observations_func=_get_observations, "get_observations_func": _get_observations,
transform_func=functools.partial( "transform_func": functools.partial(
_transform_predictions, _transform_predictions,
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms), inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms),
save_func=functools.partial(_save_predictions, "save_func": functools.partial(
_save_predictions,
output_file=output_file, output_file=output_file,
label_output_file=label_output_file, label_output_file=label_output_file,
nr_digits=nr_digits), nr_digits=nr_digits
use_entropy_threshold=conf_obj.parameters.ssd_use_entropy_threshold, )
entropy_threshold_min=conf_obj.parameters.ssd_entropy_threshold_min, })
entropy_threshold_max=conf_obj.parameters.ssd_entropy_threshold_max) )
def train(train_generator: callable, def train(train_generator: callable,
@ -285,56 +294,56 @@ def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, st
return output_file, label_output_file return output_file, label_output_file
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, def _predict_loop(generator: Generator,
dropout_step: callable, vanilla_step: callable, use_bayesian: bool,
save_images: callable, decode_func: callable, conf_obj: config.Config,
decode_func_dropout: callable, get_observations_func: callable, steps_per_epoch: int,
apply_entropy_threshold_func: callable, apply_top_k_func: callable, callables: AttributeDict) -> None:
transform_func: callable, save_func: callable,
use_entropy_threshold: bool, entropy_threshold_min: float,
entropy_threshold_max: float) -> None:
batch_counter = 0 batch_counter = 0
saved_images_prediction = False saved_images_prediction = False
saved_images_decoding = False saved_images_decoding = False
if use_entropy_threshold: if conf_obj.parameters.ssd_use_entropy_threshold:
nr_steps = math.floor((entropy_threshold_max - entropy_threshold_min) * 10) nr_steps = math.floor(
entropy_thresholds = [round(i / 10 + entropy_threshold_min, 1) for i in range(nr_steps)] (conf_obj.parameters.ssd_entropy_threshold_max - conf_obj.parameters.ssd_entropy_threshold_min) * 10
)
entropy_thresholds = [round(i / 10 + conf_obj.parameters.ssd_entropy_threshold_min, 1) for i in range(nr_steps)]
else: else:
entropy_thresholds = [0] entropy_thresholds = [0]
for inputs, filenames, inverse_transforms, original_labels in generator: for inputs, filenames, inverse_transforms, original_labels in generator:
if use_dropout: if use_bayesian:
predictions = dropout_step(inputs) predictions = callables.dropout_step(inputs)
else: else:
predictions = vanilla_step(inputs) predictions = callables.vanilla_step(inputs)
if not saved_images_prediction: if not saved_images_prediction:
save_images(inputs, predictions, custom_string="after-prediction") callables.save_images(inputs, predictions, custom_string="after-prediction")
saved_images_prediction = True saved_images_prediction = True
if use_dropout: if use_bayesian:
decoded_predictions = decode_func_dropout(predictions) decoded_predictions = callables.decode_func_dropout(predictions)
observations = get_observations_func(decoded_predictions) observations = callables.get_observations_func(decoded_predictions)
for entropy_threshold in entropy_thresholds: for entropy_threshold in entropy_thresholds:
if use_dropout: if use_bayesian:
save_func(observations, original_labels, filenames, callables.save_func(observations, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold, batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_prediction") suffix="_prediction")
decoded_predictions = apply_entropy_threshold_func(observations, decoded_predictions = callables.apply_entropy_threshold_func(observations,
entropy_threshold=entropy_threshold) entropy_threshold=entropy_threshold)
save_func(decoded_predictions, original_labels, filenames, callables.save_func(decoded_predictions, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold, batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_entropy") suffix="_entropy")
decoded_predictions = apply_top_k_func(decoded_predictions) decoded_predictions = callables.apply_top_k_func(decoded_predictions)
else: else:
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold) decoded_predictions = callables.decode_func(predictions, entropy_threshold=entropy_threshold)
if not saved_images_decoding: if not saved_images_decoding:
custom_string = f"after-decoding-{entropy_threshold}" if use_entropy_threshold else "after-decoding" custom_string = f"after-decoding-{entropy_threshold}" \
save_images(inputs, decoded_predictions, custom_string=custom_string) if conf_obj.parameters.ssd_use_entropy_threshold else "after-decoding"
callables.save_images(inputs, decoded_predictions, custom_string=custom_string)
transformed_predictions = transform_func(decoded_predictions, transformed_predictions = callables.transform_func(decoded_predictions,
inverse_transforms) inverse_transforms)
save_func(transformed_predictions, original_labels, filenames, callables.save_func(transformed_predictions, original_labels, filenames,
batch_nr=batch_counter, entropy_threshold=entropy_threshold, batch_nr=batch_counter, entropy_threshold=entropy_threshold,
suffix="_transformed") suffix="_transformed")
@ -387,10 +396,7 @@ def _decode_predictions(predictions: np.ndarray,
def _decode_predictions_dropout(predictions: np.ndarray, def _decode_predictions_dropout(predictions: np.ndarray,
decode_func: callable, decode_func: callable,
image_size: int, image_size: int,
# entropy_threshold: float,
confidence_threshold: float, confidence_threshold: float,
# iou_threshold: float,
# top_k: int
) -> List[np.ndarray]: ) -> List[np.ndarray]:
return decode_func( return decode_func(
y_pred=predictions, y_pred=predictions,
@ -492,10 +498,11 @@ def _save_predictions(transformed_predictions: Union[np.ndarray, Sequence[np.nda
def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray, def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
save_images: callable, save_images: callable,
get_coco_cat_maps_func: callable, get_coco_cat_maps_func: callable,
output_path: str, coco_path: str, output_path: str,
image_size: int, custom_string: str) -> None: conf_obj: config.Config,
custom_string: str) -> None:
save_images(inputs, predictions, save_images(inputs, predictions,
output_path, coco_path, image_size, output_path, conf_obj.paths.coco, conf_obj.parameters.ssd_image_size,
get_coco_cat_maps_func, custom_string) get_coco_cat_maps_func, custom_string)