Provided checkpoint path in validation case

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-06-10 11:21:35 +02:00
parent 4b85dd8376
commit e602e7339d

View File

@ -242,6 +242,7 @@ def _ssd_val(args: argparse.Namespace) -> None:
use_dropout = False if args.network == "ssd" else True use_dropout = False if args.network == "ssd" else True
weights_file = f"{args.weights_path}/VGG_coco_SSD_300x300_iter_400000.h5" weights_file = f"{args.weights_path}/VGG_coco_SSD_300x300_iter_400000.h5"
checkpoint_path = f"{args.weights_path}/train/{args.network}/"
output_path = f"{args.output_path}/val/{args.network}/{args.iteration}/" output_path = f"{args.output_path}/val/{args.network}/{args.iteration}/"
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
@ -262,11 +263,11 @@ def _ssd_val(args: argparse.Namespace) -> None:
) )
if args.debug: if args.debug:
with use_summary_writer.as_default(): with use_summary_writer.as_default():
ssd.predict(scenenet_data, use_dropout, output_path, weights_file, nr_digits=nr_digits, ssd.predict(scenenet_data, use_dropout, output_path, weights_file, checkpoint_path,
forward_passes_per_image=forward_passes_per_image) nr_digits=nr_digits, forward_passes_per_image=forward_passes_per_image)
else: else:
ssd.predict(scenenet_data, use_dropout, output_path, weights_file, nr_digits=nr_digits, ssd.predict(scenenet_data, use_dropout, output_path, weights_file, checkpoint_path,
forward_passes_per_image=forward_passes_per_image) nr_digits=nr_digits, forward_passes_per_image=forward_passes_per_image)
def _auto_encoder_val(args: argparse.Namespace) -> None: def _auto_encoder_val(args: argparse.Namespace) -> None: