diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 50cc42f..729a9c0 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -170,16 +170,26 @@ def predict(dataset: tf.data.Dataset, """ if weights_path is None and checkpoint_path is None: raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.") - - checkpointables = {} + + # model if use_dropout: - checkpointables.update({ - 'ssd': DropoutSSD(mode='training', weights_path=weights_path) - }) + ssd = DropoutSSD(mode='training', weights_path=weights_path) else: - checkpointables.update({ - 'ssd': SSD(mode='inference_fast', weights_path=weights_path) - }) + ssd = SSD(mode='training', weights_path=weights_path) + + checkpointables = { + 'ssd': ssd.model, + 'learning_rate_var': K.variable(0), + } + + checkpointables.update({ + # optimizer + 'ssd_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], + beta1=0.5, beta2=0.999), + # global step counter + 'global_step': tf.train.get_or_create_global_step(), + 'epoch_var': K.variable(-1, dtype=tf.int64) + }) if checkpoint_path is not None: # checkpoint @@ -188,7 +198,7 @@ def predict(dataset: tf.data.Dataset, checkpoint.restore(latest_checkpoint) outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, - nr_digits, **checkpointables) + nr_digits, checkpointables['ssd']) if verbose: print((