我正在尝试分割 CIFAR10 的训练数据,因此训练集的最后 5000 个用于验证。我的代码
size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),
但最后一个打印语句打印 50000 和 10000。当我打印 train_idx 和 val_idx 时,它不应该是 45000 和 5000 它打印正确的值([0:44999],[45000:49999] 我的代码有什么问题吗
阿波罗的战车