Optimized retrieving detections in dropout step
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
7be3ee5d4f
commit
d71a135c4b
|
@ -335,14 +335,13 @@ def _predict_dropout_step(inputs: np.ndarray, model: tf.keras.models.Model,
|
|||
get_observations_func: callable,
|
||||
batch_size: int, forward_passes_per_image: int) -> np.ndarray:
|
||||
|
||||
detections = [[] for _ in range(batch_size)]
|
||||
detections = [np.zeros((8732 * forward_passes_per_image, 73)) for _ in range(batch_size)]
|
||||
|
||||
for _ in range(forward_passes_per_image):
|
||||
for forward_pass in range(forward_passes_per_image):
|
||||
predictions = model.predict_on_batch(inputs)
|
||||
|
||||
for i in range(batch_size):
|
||||
batch_item = predictions[i]
|
||||
detections[i].extend(batch_item)
|
||||
detections[i][forward_pass * 8732:forward_pass * 8732 + 8732] = predictions[i]
|
||||
|
||||
observations = np.asarray(get_observations_func(detections))
|
||||
|
||||
|
@ -400,7 +399,7 @@ def _predict_save_images(inputs: np.ndarray, predictions: np.ndarray,
|
|||
get_coco_cat_maps_func, custom_string)
|
||||
|
||||
|
||||
def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[np.ndarray]]:
|
||||
def _get_observations(detections: Sequence[np.ndarray]) -> List[List[np.ndarray]]:
|
||||
batch_size = len(detections)
|
||||
observations = [[] for _ in range(batch_size)]
|
||||
print(f"batch size: {batch_size}")
|
||||
|
|
Loading…
Reference in New Issue