Added config option to test only pre-trained version of SSD

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
Jim Martens 2019-07-17 15:25:04 +02:00
parent f34eb93e71
commit cd79be4307
2 changed files with 13 additions and 5 deletions

View File

@ -265,12 +265,13 @@ def _ssd_test(args: argparse.Namespace) -> None:
batch_size, image_size, learning_rate, \
forward_passes_per_image, nr_classes, iou_threshold, dropout_rate, \
use_entropy_threshold, entropy_threshold_min, entropy_threshold_max, \
top_k, nr_trajectories, \
top_k, nr_trajectories, test_pretrained, \
coco_path, output_path, weights_path, ground_truth_path = _ssd_test_get_config_values(args, conf.get_property)
use_dropout = _ssd_is_dropout(args)
output_path, checkpoint_path, weights_file = _ssd_test_prepare_paths(args, output_path, weights_path)
output_path, checkpoint_path, weights_file = _ssd_test_prepare_paths(args, output_path,
weights_path, test_pretrained)
file_names, instances = _ssd_test_load_gt(ground_truth_path)
@ -508,7 +509,7 @@ def _ssd_test_get_config_values(args: argparse.Namespace,
config_get: Callable[[str], Union[str, float, int, bool]]
) -> Tuple[int, int, float, int, int, float, float,
bool, float, float,
int, int,
int, int, bool,
str, str, str, str]:
batch_size = config_get("Parameters.batch_size")
@ -523,6 +524,7 @@ def _ssd_test_get_config_values(args: argparse.Namespace,
entropy_threshold_max = config_get("Parameters.ssd_entropy_threshold_max")
top_k = config_get("Parameters.ssd_top_k")
nr_trajectories = config_get("Parameters.nr_trajectories")
test_pretrained = config_get("Parameters.ssd_test_pretrained")
coco_path = config_get("Paths.coco")
output_path = config_get("Paths.output")
@ -547,6 +549,7 @@ def _ssd_test_get_config_values(args: argparse.Namespace,
#
top_k,
nr_trajectories,
test_pretrained,
#
coco_path,
output_path,
@ -598,12 +601,16 @@ def _ssd_train_prepare_paths(args: argparse.Namespace,
def _ssd_test_prepare_paths(args: argparse.Namespace,
output_path: str, weights_path: str) -> Tuple[str, str, str]:
output_path: str, weights_path: str,
test_pretrained: bool) -> Tuple[str, str, str]:
import os
output_path = f"{output_path}/{args.network}/test/{args.iteration}/"
checkpoint_path = f"{weights_path}/{args.network}/train/{args.train_iteration}"
weights_file = f"{checkpoint_path}/ssd300_weights.h5"
if test_pretrained:
weights_file = f"{weights_path}/{args.network}/VGG_coco_SSD_300x300_iter_400000.h5"
else:
weights_file = f"{checkpoint_path}/ssd300_weights.h5"
os.makedirs(output_path, exist_ok=True)

View File

@ -62,6 +62,7 @@ _CONFIG_PROPS = {
"ssd_use_entropy_threshold": (bool, False),
"ssd_entropy_threshold_min": (float, "0.1"),
"ssd_entropy_threshold_max": (float, "2.5"),
"ssd_test_pretrained": (bool, False),
"nr_trajectories": (int, "-1")
}
}