diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index e0bf291..8cbe9a9 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -333,10 +333,14 @@ def _load_images_ssd_callback(resized_shape: Sequence[int]) \ image_data, _labels = data image = tf.image.decode_image(image_data, channels=3, dtype=tf.float32) image_shape = tf.shape(image) - x_reverse = tf.expand_dims(tf.cast(image_shape[0], dtype=tf.float32) / resized_shape[0], - axis=0) - y_reverse = tf.expand_dims(tf.cast(image_shape[1], dtype=tf.float32) / resized_shape[1], - axis=0) + x_reverse = tf.broadcast_to( + tf.expand_dims(tf.expand_dims(tf.cast(image_shape[0], dtype=tf.float32) / resized_shape[0], + axis=0), axis=0), + [tf.shape(_labels)[0], 1]) + y_reverse = tf.broadcast_to( + tf.expand_dims(tf.expand_dims(tf.cast(image_shape[1], dtype=tf.float32) / resized_shape[1], + axis=0), axis=0), + [tf.shape(_labels)[0], 1]) _labels = tf.concat([_labels, x_reverse, y_reverse], axis=1) image = tf.reshape(image, [image_shape[0], image_shape[1], 3]) image_resized = tf.image.resize_images(image, [resized_shape[0], resized_shape[1]])