Only save summaries if configured to do so
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -228,9 +228,12 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
nr_batches_train = int(math.floor(train_length / batch_size))
|
nr_batches_train = int(math.floor(train_length / batch_size))
|
||||||
nr_batches_val = int(math.floor(val_length / batch_size))
|
nr_batches_val = int(math.floor(val_length / batch_size))
|
||||||
|
|
||||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
if args.debug and conf.get_property("Debug.summaries"):
|
||||||
log_dir=f"{summary_path}/train/{args.network}/{args.iteration}"
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||||
)
|
log_dir=f"{summary_path}/train/{args.network}/{args.iteration}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensorboard_callback = None
|
||||||
|
|
||||||
history = ssd.train_keras(
|
history = ssd.train_keras(
|
||||||
train_generator,
|
train_generator,
|
||||||
|
|||||||
@ -447,7 +447,7 @@ def train_keras(train_generator: callable,
|
|||||||
initial_epoch: int,
|
initial_epoch: int,
|
||||||
nr_epochs: int,
|
nr_epochs: int,
|
||||||
lr: float,
|
lr: float,
|
||||||
tensorboard_callback: tf.keras.callbacks.TensorBoard) -> tf.keras.callbacks.History:
|
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
||||||
"""
|
"""
|
||||||
Trains the SSD on the given data set using Keras functionality.
|
Trains the SSD on the given data set using Keras functionality.
|
||||||
|
|
||||||
@ -491,9 +491,10 @@ def train_keras(train_generator: callable,
|
|||||||
save_weights_only=False
|
save_weights_only=False
|
||||||
),
|
),
|
||||||
tf.keras.callbacks.TerminateOnNaN(),
|
tf.keras.callbacks.TerminateOnNaN(),
|
||||||
tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss"),
|
tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss")
|
||||||
tensorboard_callback
|
|
||||||
]
|
]
|
||||||
|
if tensorboard_callback is not None:
|
||||||
|
callbacks.append(tensorboard_callback)
|
||||||
|
|
||||||
history = ssd_model.model.fit_generator(generator=train_generator,
|
history = ssd_model.model.fit_generator(generator=train_generator,
|
||||||
epochs=nr_epochs,
|
epochs=nr_epochs,
|
||||||
|
|||||||
Reference in New Issue
Block a user