我正在 python 中使用 Tensorflow 进行对象检测。
我想使用张量流输入管道来加载批量输入数据。问题是图像中的对象数量是可变的。
想象一下我想做以下事情。注释是图像文件名及其包含的边界框的数组。标签被排除在外。每个边界框由四个数字表示。
import tensorflow as tf
@tf.function()
def prepare_sample(annotation):
annotation_parts = tf.strings.split(annotation, sep=' ')
image_file_name = annotation_parts[0]
image_file_path = tf.strings.join(["/images/", image_file_name])
depth_image = tf.io.read_file(image_file_path)
bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])
return depth_image, bboxes
annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']
dataset = tf.data.Dataset.from_tensor_slices(annotations)
dataset = dataset.shuffle(len(annotations))
dataset = dataset.map(prepare_sample)
dataset = dataset.batch(16)
for image, bboxes in dataset:
pass
在上面的示例中,image1 包含单个对象,而 image2 包含两个对象。我收到以下错误:
InvalidArgumentError:无法将张量添加到批次:元素数量不匹配。形状为:[张量]:[1,4],[批次]:[2,4]
这就说得通了。我正在寻找从映射函数返回不同长度数组的方法。我能做些什么?
谢谢你!
编辑:我想我找到了解决方案;我不再收到错误。我dataset.batch(16)改为dataset.padded_batch(16).
绝地无双
相关分类