@ -170,27 +170,27 @@ def predict(generator: callable,
|
|||||||
output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
|
output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
|
||||||
|
|
||||||
_predict_loop(generator, use_dropout, steps_per_epoch,
|
_predict_loop(generator, use_dropout, steps_per_epoch,
|
||||||
functools.partial(_predict_dropout_step,
|
dropout_step=functools.partial(_predict_dropout_step,
|
||||||
model=model,
|
model=model,
|
||||||
get_observations_func=_get_observations,
|
get_observations_func=_get_observations,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
forward_passes_per_image=forward_passes_per_image),
|
forward_passes_per_image=forward_passes_per_image),
|
||||||
functools.partial(_predict_vanilla_step, model=model),
|
vanilla_step=functools.partial(_predict_vanilla_step, model=model),
|
||||||
functools.partial(_transform_predictions,
|
save_images=functools.partial(_predict_save_images,
|
||||||
decode_func=ssd_output_decoder.decode_detections_fast,
|
|
||||||
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms,
|
|
||||||
image_size=image_size),
|
|
||||||
functools.partial(_save_predictions,
|
|
||||||
output_file=output_file,
|
|
||||||
label_output_file=label_output_file,
|
|
||||||
nr_digits=nr_digits),
|
|
||||||
functools.partial(_predict_save_images,
|
|
||||||
save_images=debug.save_ssd_train_images,
|
save_images=debug.save_ssd_train_images,
|
||||||
get_coco_cat_maps_func=coco_utils.get_coco_category_maps,
|
get_coco_cat_maps_func=coco_utils.get_coco_category_maps,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
coco_path=coco_path,
|
coco_path=coco_path,
|
||||||
image_size=image_size)
|
image_size=image_size),
|
||||||
)
|
transform_func=functools.partial(
|
||||||
|
_transform_predictions,
|
||||||
|
decode_func=ssd_output_decoder.decode_detections_fast,
|
||||||
|
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms,
|
||||||
|
image_size=image_size),
|
||||||
|
save_func=functools.partial(_save_predictions,
|
||||||
|
output_file=output_file,
|
||||||
|
label_output_file=label_output_file,
|
||||||
|
nr_digits=nr_digits))
|
||||||
|
|
||||||
|
|
||||||
def train(train_generator: callable,
|
def train(train_generator: callable,
|
||||||
@ -274,11 +274,7 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
|||||||
else:
|
else:
|
||||||
predictions = vanilla_step(inputs)
|
predictions = vanilla_step(inputs)
|
||||||
|
|
||||||
# save_images(inputs, predictions)
|
save_images(inputs, predictions)
|
||||||
print((
|
|
||||||
f"Input shape: {inputs.shape}"
|
|
||||||
f"Predictions shape: {predictions.shape}"
|
|
||||||
))
|
|
||||||
transformed_predictions = transform_func(predictions, inverse_transforms)
|
transformed_predictions = transform_func(predictions, inverse_transforms)
|
||||||
save_func(transformed_predictions, original_labels, filenames,
|
save_func(transformed_predictions, original_labels, filenames,
|
||||||
batch_nr=batch_counter)
|
batch_nr=batch_counter)
|
||||||
|
|||||||
Reference in New Issue
Block a user