Only save summaries if configured to do so

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-04 17:08:06 +02:00
parent 662c5ec7af
commit da2b788447
2 changed files with 10 additions and 6 deletions

View File

@ -228,9 +228,12 @@ def _ssd_train(args: argparse.Namespace) -> None:
nr_batches_train = int(math.floor(train_length / batch_size))
nr_batches_val = int(math.floor(val_length / batch_size))
if args.debug and conf.get_property("Debug.summaries"):
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=f"{summary_path}/train/{args.network}/{args.iteration}"
)
else:
tensorboard_callback = None
history = ssd.train_keras(
train_generator,

View File

@ -447,7 +447,7 @@ def train_keras(train_generator: callable,
initial_epoch: int,
nr_epochs: int,
lr: float,
tensorboard_callback: tf.keras.callbacks.TensorBoard) -> tf.keras.callbacks.History:
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
"""
Trains the SSD on the given data set using Keras functionality.
@ -491,9 +491,10 @@ def train_keras(train_generator: callable,
save_weights_only=False
),
tf.keras.callbacks.TerminateOnNaN(),
tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss"),
tensorboard_callback
tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss")
]
if tensorboard_callback is not None:
callbacks.append(tensorboard_callback)
history = ssd_model.model.fit_generator(generator=train_generator,
epochs=nr_epochs,