手记

使用PyTorch加载自定义的图片及其标签

由于我下载的imagenet2012验证集所有图片都在一个文件夹,所有标签数据都在一个txt里面,因此我使用了自定义的DataSet和DataLoader进行读取。

import os
from torch.utils import data
from PIL import Image
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
transform=transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
class MyDataSet(data.Dataset):
    def __init__(self,root,target_transform=None):
        fh = open('imagenet/caffe_ilsvrc12/val.txt', 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            words[0]=os.path.join(root, words[0])
            print('img path:',words[0],'label:',words[1])
            imgs.append((words[0], int(words[1])))
            self.imgs = imgs
            self.transforms = transform
            self.target_transform = target_transform

    def __getitem__(self, index):
      #print('index:',index)
      img_path,label = self.imgs[index]
        pil_img = Image.open(img_path).convert('L')
        if self.transforms:
            data = self.transforms(pil_img)
        else:
            pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)
        return data,label
    def __len__(self):
        return len(self.imgs)

自定义的MyDataSet类继承于torch.utils.data.DataSet类。由于图片本身一部分是三通道的,一部分却是单通道的。因此如果不在读取的时候统一读入灰度图,就会报一个错误:

RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1, 224, 224] at entry 25

我原本是想读入彩色图,以便在下面直接进行展示。这个错误我不知道如何解决,因此统一读入时使用

pil_img = Image.open(img_path).convert('L')

读取单通道图片,在后面的展示中显示的也就是灰度图片。

train_dataset = MyDataSet('imagenet/val')
print(len(train_dataset))
valid_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
for image in valid_loader:
    valid_image, valid_label = image[0], image[1]
    print('valid_label:', valid_label)
    print('valid_image shape', valid_image.shape)
    print(valid_image[0].shape)
    plt.imshow(valid_image[0].squeeze(), cmap='gray')
    plt.show()
    break
valid_label: tensor([658, 283, 202, 619,  32, 758, 646, 690, 100, 546, 942, 728, 343, 969,
         80, 530, 296, 412, 163, 128, 858, 702, 507, 500, 303, 478, 342,  10,
        524, 703, 277, 777, 600, 806, 768, 353, 718, 981, 598, 519, 413, 817,
        774, 302, 263, 366,  31, 600,  48, 986,  98, 602, 409,  39, 894, 747,
        200, 384, 140, 386, 191, 952, 128, 990])
valid_image shape torch.Size([64, 1, 224, 224])
torch.Size([1, 224, 224])

0人推荐
随时随地看视频
慕课网APP