@ -39,11 +39,11 @@ def main() -> None:
|
|||||||
|
|
||||||
train_parser = sub_parsers.add_parser("train", help="Train a network")
|
train_parser = sub_parsers.add_parser("train", help="Train a network")
|
||||||
test_parser = sub_parsers.add_parser("test", help="Test a network")
|
test_parser = sub_parsers.add_parser("test", help="Test a network")
|
||||||
use_parser = sub_parsers.add_parser("use", help="Use a network")
|
val_parser = sub_parsers.add_parser("val", help="Validate a network")
|
||||||
|
|
||||||
# build sub parsers
|
# build sub parsers
|
||||||
_build_train(train_parser)
|
_build_train(train_parser)
|
||||||
_build_use(use_parser)
|
_build_val(val_parser)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -51,8 +51,8 @@ def main() -> None:
|
|||||||
_train(args)
|
_train(args)
|
||||||
elif args.action == "test":
|
elif args.action == "test":
|
||||||
_test(args)
|
_test(args)
|
||||||
elif args.action == "use":
|
elif args.action == "val":
|
||||||
_use(args)
|
_val(args)
|
||||||
|
|
||||||
|
|
||||||
def _build_train(parser: argparse.ArgumentParser) -> None:
|
def _build_train(parser: argparse.ArgumentParser) -> None:
|
||||||
@ -67,7 +67,7 @@ def _build_train(parser: argparse.ArgumentParser) -> None:
|
|||||||
_build_auto_encoder_train(auto_encoder_parser)
|
_build_auto_encoder_train(auto_encoder_parser)
|
||||||
|
|
||||||
|
|
||||||
def _build_use(parser: argparse.ArgumentParser) -> None:
|
def _build_val(parser: argparse.ArgumentParser) -> None:
|
||||||
sub_parsers = parser.add_subparsers(dest="network")
|
sub_parsers = parser.add_subparsers(dest="network")
|
||||||
sub_parsers.required = True
|
sub_parsers.required = True
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ def _build_use(parser: argparse.ArgumentParser) -> None:
|
|||||||
auto_encoder_parser = sub_parsers.add_parser("auto_encoder", help="Auto-encoder network")
|
auto_encoder_parser = sub_parsers.add_parser("auto_encoder", help="Auto-encoder network")
|
||||||
|
|
||||||
# build sub parsers
|
# build sub parsers
|
||||||
_build_auto_encoder_use(auto_encoder_parser)
|
_build_auto_encoder_val(auto_encoder_parser)
|
||||||
|
|
||||||
|
|
||||||
def _build_auto_encoder_train(parser: argparse.ArgumentParser) -> None:
|
def _build_auto_encoder_train(parser: argparse.ArgumentParser) -> None:
|
||||||
@ -87,13 +87,13 @@ def _build_auto_encoder_train(parser: argparse.ArgumentParser) -> None:
|
|||||||
parser.add_argument("iteration", type=int, help="the training iteration")
|
parser.add_argument("iteration", type=int, help="the training iteration")
|
||||||
|
|
||||||
|
|
||||||
def _build_auto_encoder_use(parser: argparse.ArgumentParser) -> None:
|
def _build_auto_encoder_val(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument("--coco_path", type=str, help="the path to the COCO data set")
|
parser.add_argument("--coco_path", type=str, help="the path to the COCO data set")
|
||||||
parser.add_argument("--weights_path", type=str, help="path to the weights directory")
|
parser.add_argument("--weights_path", type=str, help="path to the weights directory")
|
||||||
parser.add_argument("--summary_path", type=str, help="path to the summaries directory")
|
parser.add_argument("--summary_path", type=str, help="path to the summaries directory")
|
||||||
parser.add_argument("category", type=int, help="the COCO category to use")
|
parser.add_argument("category", type=int, help="the COCO category to validate")
|
||||||
parser.add_argument("category_trained", type=int, help="the trained COCO category")
|
parser.add_argument("category_trained", type=int, help="the trained COCO category")
|
||||||
parser.add_argument("iteration", type=int, help="the use iteration")
|
parser.add_argument("iteration", type=int, help="the validation iteration")
|
||||||
parser.add_argument("iteration_trained", type=int, help="the training iteration")
|
parser.add_argument("iteration_trained", type=int, help="the training iteration")
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ def _test(args: argparse.Namespace) -> None:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def _use(args: argparse.Namespace) -> None:
|
def _val(args: argparse.Namespace) -> None:
|
||||||
from twomartens.masterthesis import data
|
from twomartens.masterthesis import data
|
||||||
from twomartens.masterthesis.aae import run
|
from twomartens.masterthesis.aae import run
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@ -123,10 +123,10 @@ def _use(args: argparse.Namespace) -> None:
|
|||||||
category = args.category
|
category = args.category
|
||||||
category_trained = args.category_trained
|
category_trained = args.category_trained
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
coco_data = data.load_coco_train(coco_path, category, num_epochs=1,
|
coco_data = data.load_coco_val(coco_path, category, num_epochs=1,
|
||||||
batch_size=batch_size, resized_shape=(256, 256))
|
batch_size=batch_size, resized_shape=(256, 256))
|
||||||
use_summary_writer = summary_ops_v2.create_file_writer(
|
use_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
f"{args.summary_path}/use/category-{category}/{args.iteration}"
|
f"{args.summary_path}/val/category-{category}/{args.iteration}"
|
||||||
)
|
)
|
||||||
with use_summary_writer.as_default():
|
with use_summary_writer.as_default():
|
||||||
run.run_simple(coco_data, iteration=args.iteration_trained, debug=args.debug,
|
run.run_simple(coco_data, iteration=args.iteration_trained, debug=args.debug,
|
||||||
|
|||||||
Reference in New Issue
Block a user