在 Tensorflow 数据集管道中返回不同长度的数组

我正在 python 中使用 Tensorflow 进行对象检测。


我想使用张量流输入管道来加载批量输入数据。问题是图像中的对象数量是可变的。


想象一下我想做以下事情。注释是图像文件名及其包含的边界框的数组。标签被排除在外。每个边界框由四个数字表示。


import tensorflow as tf


@tf.function()

def prepare_sample(annotation):

    annotation_parts = tf.strings.split(annotation, sep=' ')

    image_file_name = annotation_parts[0]

    image_file_path = tf.strings.join(["/images/", image_file_name])

    depth_image = tf.io.read_file(image_file_path)

    bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])

    return depth_image, bboxes


annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']

dataset = tf.data.Dataset.from_tensor_slices(annotations)

dataset = dataset.shuffle(len(annotations))

dataset = dataset.map(prepare_sample)

dataset = dataset.batch(16)


for image, bboxes in dataset:

  pass

在上面的示例中,image1 包含单个对象,而 image2 包含两个对象。我收到以下错误:


InvalidArgumentError:无法将张量添加到批次:元素数量不匹配。形状为:[张量]:[1,4],[批次]:[2,4]


这就说得通了。我正在寻找从映射函数返回不同长度数组的方法。我能做些什么?


谢谢你!


编辑:我想我找到了解决方案;我不再收到错误。我dataset.batch(16)改为dataset.padded_batch(16).


慕虎7371278
浏览 4431回答 1
1回答

绝地无双

dataset.batch(16)更改为 后该错误将得到解决dataset.padded_batch(16)。下面是相同的修改后的代码。import tensorflow as tf@tf.function()def prepare_sample(annotation):    annotation_parts = tf.strings.split(annotation, sep=' ')    image_file_name = annotation_parts[0]    image_file_path = tf.strings.join(["/images/", image_file_name])    depth_image = tf.io.read_file(image_file_path)    bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])    return depth_image, bboxesannotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']dataset = tf.data.Dataset.from_tensor_slices(annotations)dataset = dataset.shuffle(len(annotations))dataset = dataset.map(prepare_sample)dataset = dataset.padded_batch(16)for image, bboxes in dataset:  pass
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python