From d71a135c4bdaaaeb2fc9e414b19dc3bf83320846 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Sun, 11 Aug 2019 20:17:52 +0200 Subject: [PATCH] Optimized retrieving detections in dropout step Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 22d75e6..5e9d850 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -335,14 +335,13 @@ def _predict_dropout_step(inputs: np.ndarray, model: tf.keras.models.Model, get_observations_func: callable, batch_size: int, forward_passes_per_image: int) -> np.ndarray: - detections = [[] for _ in range(batch_size)] + detections = [np.zeros((8732 * forward_passes_per_image, 73)) for _ in range(batch_size)] - for _ in range(forward_passes_per_image): + for forward_pass in range(forward_passes_per_image): predictions = model.predict_on_batch(inputs) for i in range(batch_size): - batch_item = predictions[i] - detections[i].extend(batch_item) + detections[i][forward_pass * 8732:forward_pass * 8732 + 8732] = predictions[i] observations = np.asarray(get_observations_func(detections)) @@ -400,7 +399,7 @@ def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray, get_coco_cat_maps_func, custom_string) -def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[np.ndarray]]: +def _get_observations(detections: Sequence[np.ndarray]) -> List[List[np.ndarray]]: batch_size = len(detections) observations = [[] for _ in range(batch_size)] print(f"batch size: {batch_size}")