我正在尝试在torchtext中使用BucketIterator.splits函数从csv文件中加载数据以在CNN中使用。除非我的批处理中最长的句子比最大的过滤器大小短,否则一切都正常。
在我的示例中,我使用了大小分别为3、4和5的过滤器,因此,如果最长的句子没有至少5个单词,则会出现错误。有没有一种方法可以让BucketIterator动态设置批次的填充,还可以设置最小填充长度?
这是我用于BucketIterator的代码:
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
我希望有一种方法可以设置sort_key或类似的最小长度?
我尝试了这个,但是不起作用:
FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device)
潇潇雨雨
墨色风雨
相关分类