From 5c9eaab6e4d0a81c9e255fdd7642c17ffdcba40f Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Mon, 15 Jul 2019 11:03:09 +0200 Subject: [PATCH] Fixed function calls Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 48 ++++++++++++++---------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 20fec31..82314c3 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -170,27 +170,27 @@ def predict(generator: callable, output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout) _predict_loop(generator, use_dropout, steps_per_epoch, - functools.partial(_predict_dropout_step, - model=model, - get_observations_func=_get_observations, - batch_size=batch_size, - forward_passes_per_image=forward_passes_per_image), - functools.partial(_predict_vanilla_step, model=model), - functools.partial(_transform_predictions, - decode_func=ssd_output_decoder.decode_detections_fast, - inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms, - image_size=image_size), - functools.partial(_save_predictions, - output_file=output_file, - label_output_file=label_output_file, - nr_digits=nr_digits), - functools.partial(_predict_save_images, - save_images=debug.save_ssd_train_images, - get_coco_cat_maps_func=coco_utils.get_coco_category_maps, - output_path=output_path, - coco_path=coco_path, - image_size=image_size) - ) + dropout_step=functools.partial(_predict_dropout_step, + model=model, + get_observations_func=_get_observations, + batch_size=batch_size, + forward_passes_per_image=forward_passes_per_image), + vanilla_step=functools.partial(_predict_vanilla_step, model=model), + save_images=functools.partial(_predict_save_images, + save_images=debug.save_ssd_train_images, + get_coco_cat_maps_func=coco_utils.get_coco_category_maps, + output_path=output_path, + coco_path=coco_path, + image_size=image_size), + transform_func=functools.partial( + _transform_predictions, + decode_func=ssd_output_decoder.decode_detections_fast, + inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms, + image_size=image_size), + save_func=functools.partial(_save_predictions, + output_file=output_file, + label_output_file=label_output_file, + nr_digits=nr_digits)) def train(train_generator: callable, @@ -274,11 +274,7 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, else: predictions = vanilla_step(inputs) - # save_images(inputs, predictions) - print(( - f"Input shape: {inputs.shape}" - f"Predictions shape: {predictions.shape}" - )) + save_images(inputs, predictions) transformed_predictions = transform_func(predictions, inverse_transforms) save_func(transformed_predictions, original_labels, filenames, batch_nr=batch_counter)