Finished conversion of training functionality to keras
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -25,6 +25,7 @@ Functions:
|
|||||||
prepare(...): prepares the SceneNet ground truth data
|
prepare(...): prepares the SceneNet ground truth data
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def train(args: argparse.Namespace) -> None:
|
def train(args: argparse.Namespace) -> None:
|
||||||
@ -39,7 +40,6 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.ops import summary_ops_v2
|
|
||||||
|
|
||||||
from twomartens.masterthesis import data
|
from twomartens.masterthesis import data
|
||||||
from twomartens.masterthesis import ssd
|
from twomartens.masterthesis import ssd
|
||||||
@ -55,35 +55,58 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
os.makedirs(weights_path, exist_ok=True)
|
os.makedirs(weights_path, exist_ok=True)
|
||||||
|
|
||||||
# load prepared ground truth
|
# load prepared ground truth
|
||||||
with open(f"{args.ground_truth_path}/photo_paths.bin", "rb") as file:
|
with open(f"{args.ground_truth_path_train}/photo_paths.bin", "rb") as file:
|
||||||
file_names_photos = pickle.load(file)
|
file_names_train = pickle.load(file)
|
||||||
with open(f"{args.ground_truth_path}/instances.bin", "rb") as file:
|
with open(f"{args.ground_truth_path_train}/instances.bin", "rb") as file:
|
||||||
instances = pickle.load(file)
|
instances_train = pickle.load(file)
|
||||||
|
with open(f"{args.ground_truth_path_val}/photo_paths.bin", "rb") as file:
|
||||||
|
file_names_val = pickle.load(file)
|
||||||
|
with open(f"{args.ground_truth_path_val}/instances.bin", "rb") as file:
|
||||||
|
instances_val = pickle.load(file)
|
||||||
|
|
||||||
scenenet_data, nr_digits, length_dataset = \
|
# model
|
||||||
data.load_scenenet_data(file_names_photos, instances, args.coco_path,
|
if use_dropout:
|
||||||
batch_size=batch_size, num_epochs=args.num_epochs,
|
ssd_model = ssd.DropoutSSD(mode='training', weights_path=pre_trained_weights_file)
|
||||||
|
else:
|
||||||
|
ssd_model = ssd.SSD(mode='training', weights_path=pre_trained_weights_file)
|
||||||
|
|
||||||
|
train_generator, train_length = \
|
||||||
|
data.load_scenenet_data(file_names_train, instances_train, args.coco_path,
|
||||||
|
predictor_sizes=ssd_model.predictor_sizes,
|
||||||
|
batch_size=batch_size,
|
||||||
resized_shape=(image_size, image_size),
|
resized_shape=(image_size, image_size),
|
||||||
mode="training")
|
mode="training")
|
||||||
del file_names_photos, instances
|
val_generator, val_length = \
|
||||||
|
data.load_scenenet_data(file_names_val, instances_val, args.coco_path,
|
||||||
|
predictor_sizes=ssd_model.predictor_sizes,
|
||||||
|
batch_size=batch_size,
|
||||||
|
resized_shape=(image_size, image_size),
|
||||||
|
mode="validation")
|
||||||
|
del file_names_train, instances_train, file_names_val, instances_val
|
||||||
|
|
||||||
use_summary_writer = summary_ops_v2.create_file_writer(
|
nr_batches_train = int(math.ceil(train_length / float(batch_size)))
|
||||||
f"{args.summary_path}/train/{args.network}/{args.iteration}"
|
nr_batches_val = int(math.ceil(val_length / float(batch_size)))
|
||||||
|
|
||||||
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||||
|
log_dir=f"{args.summary_path}/train/{args.network}/{args.iteration}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.debug:
|
history = ssd.train_keras(
|
||||||
with use_summary_writer.as_default():
|
train_generator,
|
||||||
ssd.train(scenenet_data, args.iteration, use_dropout, length_dataset,
|
nr_batches_train,
|
||||||
weights_prefix=weights_path,
|
val_generator,
|
||||||
weights_path=pre_trained_weights_file, batch_size=batch_size,
|
nr_batches_val,
|
||||||
nr_epochs=args.num_epochs,
|
ssd_model,
|
||||||
verbose=args.verbose)
|
weights_path,
|
||||||
else:
|
args.iteration,
|
||||||
ssd.train(scenenet_data, args.iteration, use_dropout, length_dataset,
|
initial_epoch=0,
|
||||||
weights_prefix=weights_path,
|
nr_epochs=args.num_epochs,
|
||||||
weights_path=pre_trained_weights_file, batch_size=batch_size,
|
lr=0.001,
|
||||||
nr_epochs=args.num_epochs,
|
tensorboard_callback=tensorboard_callback
|
||||||
verbose=args.verbose)
|
)
|
||||||
|
|
||||||
|
with open(f"{args.summary_path}/train/{args.network}/{args.iteration}/history", "wb") as file:
|
||||||
|
pickle.dump(history, file)
|
||||||
|
|
||||||
|
|
||||||
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||||
|
|||||||
@ -85,7 +85,10 @@ def _build_train(parser: argparse.ArgumentParser) -> None:
|
|||||||
def _build_ssd_train(parser: argparse.ArgumentParser) -> None:
|
def _build_ssd_train(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument("--coco_path", type=str, help="the path to the COCO data set")
|
parser.add_argument("--coco_path", type=str, help="the path to the COCO data set")
|
||||||
parser.add_argument("--weights_path", type=str, help="path to the weights directory")
|
parser.add_argument("--weights_path", type=str, help="path to the weights directory")
|
||||||
parser.add_argument("--ground_truth_path", type=str, help="path to the prepared ground truth directory")
|
parser.add_argument("--ground_truth_path_train", type=str,
|
||||||
|
help="path to the prepared ground truth directory for training")
|
||||||
|
parser.add_argument("--ground_truth_path_val", type=str,
|
||||||
|
help="path to the prepared ground truth directory for validation")
|
||||||
parser.add_argument("--summary_path", type=str, help="path to the summaries directory")
|
parser.add_argument("--summary_path", type=str, help="path to the summaries directory")
|
||||||
parser.add_argument("num_epochs", type=int, help="the number of epochs to train", default=80)
|
parser.add_argument("num_epochs", type=int, help="the number of epochs to train", default=80)
|
||||||
parser.add_argument("iteration", type=int, help="the training iteration")
|
parser.add_argument("iteration", type=int, help="the training iteration")
|
||||||
|
|||||||
Reference in New Issue
Block a user