我正在尝试使用PyTorch实现图像分类器(CNN / ConvNet),在这里我想从csv文件中读取标签。我有4个不同的类别,一张图片可能属于多个类别。
我已经阅读了PyTorch教程,本斯坦福教程以及本教程,但都没有涉及我的具体情况。我设法建立了torch.utils.data.Dataset
该类的自定义函数,该函数对于从csv文件中读取标签(仅适用于二进制分类器)的效果很好。
这是torch.utils.data.Dataset
我到目前为止所使用的类的代码(与上面链接的第三个教程稍作修改):
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import numpy as np
import pandas as pd
class MyCustomDataset(data.Dataset):
# __init__ function is where the initial logic happens like reading a csv,
# assigning transforms etc.
def __init__(self, csv_path):
# Transforms
self.random_crop = transforms.RandomCrop(800)
self.to_tensor = transforms.ToTensor()
# Read the csv file
self.data_info = pd.read_csv(csv_path, header=None)
# First column contains the image paths
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# Second column is the labels
self.label_arr = np.asarray(self.data_info.iloc[:, 1])
# Calculate len
self.data_len = len(self.data_info.index)
# __getitem__ function returns the data and labels. This function is
# called from dataloader like this
def __getitem__(self, index):
# Get image name from the pandas df
single_image_name = self.image_arr[index]
# Open image
img_as_img = Image.open(single_image_name)
img_cropped = self.random_crop(img_as_img)
img_as_tensor = self.to_tensor(img_cropped)
# Get label(class) of the image based on the cropped pandas column
single_image_label = self.label_arr[index]
return (img_as_tensor, single_image_label)
def __len__(self):
return self.data_len
具体来说,我正在尝试从具有以下结构的文件中读取标签:
我的具体问题是,我无法弄清楚如何在Dataset
班级中实现这一点。我想我在csv中的标签的(手动)分配与PyTorch如何读取它们之间缺少联系,因为我对框架不是很熟悉。
我非常感谢您提供有关如何使其正常工作的帮助,或者,如果确实有涉及此方面的示例,那么也将非常感谢您提供链接!
万千封印
相关分类