Fixes dimensions of x and y reverse values

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-05-22 15:09:58 +02:00
parent 693fb4d55c
commit d398ecba04

View File

@ -333,10 +333,14 @@ def _load_images_ssd_callback(resized_shape: Sequence[int]) \
image_data, _labels = data
image = tf.image.decode_image(image_data, channels=3, dtype=tf.float32)
image_shape = tf.shape(image)
x_reverse = tf.expand_dims(tf.cast(image_shape[0], dtype=tf.float32) / resized_shape[0],
axis=0)
y_reverse = tf.expand_dims(tf.cast(image_shape[1], dtype=tf.float32) / resized_shape[1],
axis=0)
x_reverse = tf.broadcast_to(
tf.expand_dims(tf.expand_dims(tf.cast(image_shape[0], dtype=tf.float32) / resized_shape[0],
axis=0), axis=0),
[tf.shape(_labels)[0], 1])
y_reverse = tf.broadcast_to(
tf.expand_dims(tf.expand_dims(tf.cast(image_shape[1], dtype=tf.float32) / resized_shape[1],
axis=0), axis=0),
[tf.shape(_labels)[0], 1])
_labels = tf.concat([_labels, x_reverse, y_reverse], axis=1)
image = tf.reshape(image, [image_shape[0], image_shape[1], 3])
image_resized = tf.image.resize_images(image, [resized_shape[0], resized_shape[1]])