在 python 中,我有一个迭代器返回一个固定范围内的无限索引字符串,[0, N]称为Sampler. 实际上我有一个列表,它们所做的只是返回范围内的索引[0, N_0], [N_0, N_1], ..., [N_{n-1}, N_n].
我现在要做的是首先根据范围的长度选择这些迭代器中的一个,所以我有一个weights列表[N_0, N_1 - N_0, ...],我选择其中一个:
iterator_idx = random.choices(range(len(weights)), weights=weights/weights.sum())[0]
接下来,我想要做的是创建一个迭代器,它随机选择一个迭代器并选择一批M样本。
class BatchSampler:
def __init__(self, M):
self.M = M
self.weights = [weight_list]
self.samplers = [list_of_iterators]
]
self._batch_samplers = [
self.batch_sampler(sampler) for sampler in self.samplers
]
def batch_sampler(self, sampler):
batch = []
for batch_idx in sampler:
batch.append(batch_idx)
if len(batch) == self.M:
yield batch
if len(batch) > 0:
yield batch
def __iter__(self):
# First select one of the datasets.
iterator_idx = random.choices(
range(len(self.weights)), weights=self.weights / self.weights.sum()
)[0]
return self._batch_samplers[iterator_idx]
问题在于它似乎iter()只被调用一次,所以只iterator_idx选择了第一次。显然这是错误的......解决这个问题的方法是什么?
当您在 pytorch 中有多个数据集时,可能会出现这种情况,但您只想从其中一个数据集中采样批次。
ibeautiful
相关分类