我正在尝试将 Google Storage 上的目录解析为字符串,但我不断收到错误。我想找到每个文件的目录并将目录名称的数字编码作为数据集返回。这在使用 LabelEncoder 的 sklearn 中是微不足道的,但我在 Tensorflow 中做这件事时遇到了麻烦。
CLASS_NAMES = [b'class_1', b'class_2', b'class_3']
labeler = tfds.features.ClassLabel(names=CLASS_NAMES)
def parse_filenames(filename):
label = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
label = label.values[-2]
# Problem is in the two lines below
position_feature = tf.feature_column.categorical_column_with_vocabulary_list('label_names', CLASS_NAMES)
label = tf.io.parse_example(label, features=position_feature)
return label
folder = b'gs://<bucket>/train/*/*.jpg'
filenames_dataset = tf.data.Dataset.list_files(folder)
label_dataset = filenames_dataset.map(parse_filenames)
next(iter(label_dataset))
我得到一个错误ValueError: dictionary update sequence element #0 has length 16; 2 is required
如果我删除“# Problem is here”注释下的两行,它工作正常,除了它返回一个字符串而不是一个整数。我已经尝试过其他非张量流选项,例如 <list_name>.index(label),但那些当然会失败,因为一切都是张量而不是字符串。还有另一种方法吗?
千巷猫影
慕森卡
相关分类