由于我下载的 ImageNet 2012 验证集中，所有图片都放在同一个文件夹里，所有标签都记录在同一个 txt 文件中，因此我自定义了 Dataset，并配合 DataLoader 进行读取。
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils import data
from torchvision import datasets, transforms
# Preprocessing pipeline applied to every validation image.
transform = transforms.Compose(
    [
        # Resize straight to 224x224; the (h, w) tuple form does not keep
        # the original aspect ratio.
        transforms.Resize((224, 224)),
        # No-op after the exact resize above (the image is already 224x224);
        # kept so the pipeline is unchanged.
        transforms.CenterCrop(224),
        # PIL image -> tensor shaped (C, H, W).
        transforms.ToTensor(),
    ]
)
class MyDataSet(data.Dataset):
    """ImageNet-2012 validation images listed in a single label file.

    Each non-blank line of the label file is ``<relative image path> <int label>``.
    The path is joined onto ``root``. Items are ``(image, label)`` pairs.

    Args:
        root: Directory that contains the image files.
        target_transform: Optional callable applied to each integer label in
            ``__getitem__``.
        label_file: Path of the ``<path> <label>`` text file. Defaults to the
            location this script was written for.
        image_transform: Callable applied to each PIL image. When omitted, the
            module-level ``transform`` pipeline is used. Pass ``None`` to get a
            raw tensor built directly from the pixel array instead.
        mode: PIL mode every image is converted to before transforming.
            ``'L'`` (default) yields single-channel images. ``'RGB'`` gives
            every image three channels (grayscale ones included), so samples
            have matching shapes, can be batched together, and keep colour.
        verbose: If true, print every parsed path/label pair.
    """

    # Sentinel meaning "use the module-level ``transform``"; it lets callers
    # pass ``None`` explicitly to disable image transforms.
    _MODULE_TRANSFORM = object()

    def __init__(self, root, target_transform=None, *,
                 label_file='imagenet/caffe_ilsvrc12/val.txt',
                 image_transform=_MODULE_TRANSFORM, mode='L', verbose=True):
        imgs = []
        # ``with`` guarantees the label file is closed after parsing.
        with open(label_file, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # skip blank lines instead of raising IndexError
                img_path = os.path.join(root, words[0])
                if verbose:
                    print('img path:', img_path, 'label:', words[1])
                imgs.append((img_path, int(words[1])))
        self.imgs = imgs
        # The module global is looked up only when no override was supplied,
        # so ``transform`` need not exist when a pipeline is passed in.
        self.transforms = (transform
                           if image_transform is MyDataSet._MODULE_TRANSFORM
                           else image_transform)
        self.target_transform = target_transform
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``."""
        img_path, label = self.imgs[index]
        # Close the underlying file right away; convert() returns a new,
        # fully loaded image, so it stays usable after the ``with`` block.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert(self.mode)
        if self.transforms is not None:
            sample = self.transforms(pil_img)
        else:
            # np.array (not np.asarray) copies into a writable array, which
            # torch.from_numpy expects.
            sample = torch.from_numpy(np.array(pil_img))
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Number of labelled images read from the label file."""
        return len(self.imgs)
自定义的 MyDataSet 类继承自 torch.utils.data.Dataset 类。数据集中的图片一部分是三通道（彩色）的，另一部分却是单通道（灰度）的。如果读取时不统一通道数（例如统一读入灰度图），DataLoader 在把一个 batch 中形状不同的张量堆叠（stack）到一起时，就会报如下错误：
RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1, 224, 224] at entry 25
我原本想读入彩色图，以便在下面直接展示。但当时我不知道如何解决这个错误，因此读取时统一使用了下面这行代码：
# Force every image to single-channel grayscale ('L') so all samples share
# the shape [1, 224, 224] and can be stacked into one batch.
pil_img = Image.open(img_path).convert('L')
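即统一读取为单通道图片，因此后面展示出来的也是灰度图片。（补充：把 'L' 改为 'RGB'，即 `convert('RGB')`，同样可以解决上面的报错——单通道图片会被扩展为三通道，所有样本的形状统一为 [3, 224, 224]，从而保留彩色；展示时需要先用 `permute(1, 2, 0)` 把张量转换为 (H, W, C) 的形状。）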
# Images per batch. It must be defined before building the loader; 64 matches
# the batch shown in the example output ([64, 1, 224, 224]).
BATCH_SIZE = 64

# NOTE: despite its name, this dataset wraps the ImageNet *validation* split.
train_dataset = MyDataSet('imagenet/val')
print(len(train_dataset))
# Shuffled so each run displays a different random batch.
valid_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=BATCH_SIZE, shuffle=True)
# Peek at the first batch and display its first image.
for valid_image, valid_label in valid_loader:
    print('valid_label:', valid_label)
    print('valid_image shape', valid_image.shape)
    first_image = valid_image[0]
    print(first_image.shape)
    if first_image.shape[0] == 1:
        # Single-channel (C=1): drop the channel axis and show as grayscale.
        plt.imshow(first_image.squeeze(0), cmap='gray')
    else:
        # Multi-channel: matplotlib expects (H, W, C), tensors are (C, H, W).
        plt.imshow(first_image.permute(1, 2, 0))
    plt.show()
    break  # only the first batch is inspected
valid_label: tensor([658, 283, 202, 619, 32, 758, 646, 690, 100, 546, 942, 728, 343, 969, 80, 530, 296, 412, 163, 128, 858, 702, 507, 500, 303, 478, 342, 10, 524, 703, 277, 777, 600, 806, 768, 353, 718, 981, 598, 519, 413, 817, 774, 302, 263, 366, 31, 600, 48, 986, 98, 602, 409, 39, 894, 747, 200, 384, 140, 386, 191, 952, 128, 990]) valid_image shape torch.Size([64, 1, 224, 224]) torch.Size([1, 224, 224])