How to batch tensors of arbitrary shape with TFRecordDataset?
After batching, I want my bounding boxes to have shape [batch_size, arbitrary, 4] and my classes to have shape [batch_size, arbitrary, 1].
import tensorflow as tf  # TF 1.x API: tf.FixedLenFeature, tf.parse_single_example, tf.decode_raw

def decode(serialized_example):
    """
    Decodes the information of the TFRecords to image, label_coord, label_classes.
    Later on will also contain the Image Sequence!
    :param serialized_example: Serialized Example read from the TFRecords
    :return: image, label_coordinates list, label_classes list
    """
    features = {'image/shape': tf.FixedLenFeature([], tf.string),
                'train/image': tf.FixedLenFeature([], tf.string),
                'label/coordinates': tf.VarLenFeature(tf.float32),
                'label/classes': tf.VarLenFeature(tf.string)}
    features = tf.parse_single_example(serialized_example, features=features)

    # The image shape and pixels are stored as raw bytes
    image_shape = tf.decode_raw(features['image/shape'], tf.int64)
    image = tf.decode_raw(features['train/image'], tf.float32)
    image = tf.reshape(image, image_shape)

    # Contains the Bounding Box coordinates in a flattened tensor
    label_coord = features['label/coordinates']
    label_coord = label_coord.values
    label_coord = tf.reshape(label_coord, [1, -1, 4])

    # Contains the Classes of the BBox in a flattened Tensor
    label_classes = features['label/classes']
    label_classes = label_classes.values
    label_classes = tf.reshape(label_classes, [1, -1, 1])

    return image, label_coord, label_classes

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
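To see why the records end up with different shapes, here is a minimal NumPy sketch of the two `tf.reshape` calls in `decode`. It assumes each box is stored as 4 consecutive floats in the flat `label/coordinates` list; the helper name `reshape_boxes` and the sample data are hypothetical.

```python
import numpy as np

def reshape_boxes(flat_coords, flat_classes):
    """Mimic decode(): reshape flat values to [1, -1, 4] and [1, -1, 1]."""
    coords = np.asarray(flat_coords, dtype=np.float32).reshape(1, -1, 4)
    classes = np.asarray(flat_classes).reshape(1, -1, 1)
    return coords, classes

# Hypothetical records: one with a single box, one with seven boxes.
one_box, one_cls = reshape_boxes([0.1, 0.2, 0.5, 0.6], [b"car"])
seven_box, seven_cls = reshape_boxes(np.zeros(7 * 4), [b"car"] * 7)
print(one_box.shape, seven_box.shape)  # (1, 1, 4) (1, 7, 4)
```

The middle dimension is the per-record box count, so it differs between records. Note also that even if the shapes matched, `batch` would add a new leading axis, giving `[batch_size, 1, N, 4]` rather than `[batch_size, N, 4]`, because of the leading `1` in the reshape.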
The error thrown is: Cannot batch tensors with different shapes in component 1. First element had shape [1,1,4] and element 1 had shape [1,7,4].
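`Dataset.batch` stacks elements, so every component must have the same shape. `tf.data.Dataset.padded_batch` instead pads each variable-length dimension up to the longest element in the batch. The NumPy sketch below only emulates that padding for the box and class components; it is not TensorFlow's implementation. It assumes each example's boxes have shape `[N, 4]` and classes `[N, 1]` (i.e. reshaped with `[-1, 4]` / `[-1, 1]` instead of `[1, -1, 4]` / `[1, -1, 1]`). The helper `pad_and_stack` and the padding values `-1.0` and `b""` are arbitrary, hypothetical choices.

```python
import numpy as np

def pad_and_stack(arrays, pad_value):
    """Emulate padded batching: [N_i, D] arrays -> one [B, max(N_i), D] array."""
    max_n = max(a.shape[0] for a in arrays)
    depth = arrays[0].shape[1]
    dtype = np.result_type(*arrays)
    # Start from an array filled with the padding value ...
    batch = np.full((len(arrays), max_n, depth), pad_value, dtype=dtype)
    # ... then copy each example's rows into the front of its slot.
    for i, a in enumerate(arrays):
        batch[i, : a.shape[0]] = a
    return batch

# Hypothetical examples matching the error message: 1 box and 7 boxes.
boxes = [np.zeros((1, 4), np.float32), np.ones((7, 4), np.float32)]
classes = [np.array([[b"car"]]), np.array([[b"person"]] * 7)]

box_batch = pad_and_stack(boxes, pad_value=-1.0)
class_batch = pad_and_stack(classes, pad_value=b"")
print(box_batch.shape, class_batch.shape)  # (2, 7, 4) (2, 7, 1)
```

In the real pipeline, the corresponding change would be replacing `dataset.batch(batch_size)` with `dataset.padded_batch(...)`. In TF 1.x, `padded_shapes` must be passed explicitly, with `None` marking each variable-length dimension. The padded rows are not real boxes, so downstream code needs a way to ignore them, for example by also returning the true box count per example.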