带有PyTorch的多标签,多类别图像分类器(ConvNet)

我正在尝试使用PyTorch实现图像分类器(CNN / ConvNet),在这里我想从csv文件中读取标签。我有4个不同的类别,一张图片可能属于多个类别。

我已经阅读了PyTorch教程,本斯坦福教程以及本教程,但都没有涉及我的具体情况。我设法建立了torch.utils.data.Dataset该类的自定义函数,该函数对于从csv文件中读取标签(仅适用于二进制分类器)的效果很好。

这是torch.utils.data.Dataset我到目前为止所使用的类的代码(与上面链接的第三个教程稍作修改):

import torch

import torchvision.transforms as transforms

import torch.utils.data as data

from PIL import Image

import numpy as np

import pandas as pd



class MyCustomDataset(data.Dataset):

# __init__ function is where the initial logic happens like reading a csv,

# assigning transforms etc.

def __init__(self, csv_path):

    # Transforms

    self.random_crop = transforms.RandomCrop(800)

    self.to_tensor = transforms.ToTensor()

    # Read the csv file

    self.data_info = pd.read_csv(csv_path, header=None)

    # First column contains the image paths

    self.image_arr = np.asarray(self.data_info.iloc[:, 0])

    # Second column is the labels

    self.label_arr = np.asarray(self.data_info.iloc[:, 1])

    # Calculate len

    self.data_len = len(self.data_info.index)



# __getitem__ function returns the data and labels. This function is

# called from dataloader like this

def __getitem__(self, index):

    # Get image name from the pandas df

    single_image_name = self.image_arr[index]

    # Open image

    img_as_img = Image.open(single_image_name)

    img_cropped = self.random_crop(img_as_img)

    img_as_tensor = self.to_tensor(img_cropped)


    # Get label(class) of the image based on the cropped pandas column

    single_image_label = self.label_arr[index]


    return (img_as_tensor, single_image_label)


def __len__(self):

    return self.data_len

具体来说,我正在尝试从具有以下结构的文件中读取标签:

http://img3.mukewang.com/6075361000017ae307090294.jpg

我的具体问题是,我无法弄清楚如何在Dataset班级中实现这一点。我想我在csv中的标签的(手动)分配与PyTorch如何读取它们之间缺少联系,因为我对框架不是很熟悉。
我非常感谢您提供有关如何使其正常工作的帮助,或者,如果确实有涉及此方面的示例,那么也将非常感谢您提供链接!

慕雪6442864
浏览 1117回答 1
1回答

万千封印

也许我失去了一些东西,但如果你想你的列转换1..N(N = 4这里)成标签载体或形状(N,)(例如给您的数据。例如,label(img1) = [0, 0, 0, 1],label(img3) = [1, 0, 1, 0],...),为什么不:将所有标签列读入self.label_arr:self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # columns 1 to N相应地返回中的标签__getitem__()(此处不变):single_image_label = self.label_arr[index]为了训练您的分类器,您可以计算例如(N,)预测和目标标签之间的交叉熵。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python