PyTorch - 使用 torchvision.datasets.ImageFolder

我按照以下方式构建了我的数据集:


dataset/train/0/456.jpg

dataset/train/1/456456.jpg

dataset/train/2/456.jpg

dataset/train/...


dataset/val/0/878.jpg

dataset/val/1/234.jpg

dataset/val/2/34554.jpg

dataset/val/...

所以我曾经torchvision.datasets.ImageFolder将我的数据集导入 PyTorch。然而,它似乎没有给正确的图像贴上正确的标签。我在下面添加了我的代码:


data_transforms = {

    'train': transforms.Compose(

        [transforms.Resize((176,176)),

         transforms.RandomRotation((0,360)),

         transforms.RandomHorizontalFlip(),

         transforms.RandomVerticalFlip(),

         transforms.CenterCrop(128),         

         transforms.Grayscale(),

         transforms.ToTensor(),

         transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))

    ]),

    'val': transforms.Compose(

        [transforms.Resize((128,128)),

         transforms.Grayscale(),

         transforms.ToTensor(),

         transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))

    ]),

}


data_dir = 'dataset'

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),

                                          data_transforms[x])

                  for x in ['train', 'val']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,

                                             shuffle=True, num_workers=4)

              for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

class_names = image_datasets['train'].classes


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

我发现标签是错误的,使用以下函数:


def imshow(img):

    img = img / 2 + 0.5

    npimg = img.numpy()

    plt.imshow(np.transpose(npimg, (1, 2, 0)))

    plt.show()


dataiter = iter(dataloaders['val'])

images, labels = dataiter.next()


imshow(torchvision.utils.make_grid(images))

print(labels)

使用显示的图像和标签,我手动检查它们是否正确。不幸的是,标签与图像不对应。有人能告诉我我做错了什么吗?


陪伴而非守候
浏览 687回答 2
2回答

当年话下

有人帮我解决了这个问题。ImageFolder 创建自己的内部标签。通过打印,image_datasets['train'].class_to_idx您可以看到哪个标签与哪个内部标签配对。使用这本词典,您可以追溯原始标签。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python