猿问

为了通过神经网络运行数据集,我需要做哪些转换?

我是深度学习和 Pytorch 的新手,但我希望有人能帮我解决这个问题。我的数据集包含不同大小的图像。我正在尝试创建一个可以对图像进行分类的简单神经网络。但是,我遇到了不匹配错误。


神经网络


class Net(nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3)

        self.conv2 = nn.Conv2d(32, 32, 3)

        self.fc1 = nn.Linear(32 * 3 * 3, 200)

        self.fc2 = nn.Linear(200, 120)


    def forward(self, x):

        x = F.relu(self.conv1(x))

        x = F.relu(self.conv2(x))

        x = F.relu(self.fc1(x))

        x = self.fc2(x)

        return x

net = Net()

我的第一个卷积层有 1 个输入通道,因为我将图像转换为灰度图像。32 个输出通道是一个任意决定。最后的全连接层有 120 个输出通道,因为有 120 个不同的类。


确定转换并分配训练集和验证集


transform = transforms.Compose(

    [transforms.Grayscale(1),

     transforms.RandomCrop((32,32)),

     transforms.ToTensor(),

     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


data_dir = 'dataset'

full_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform = transform)


train_size = int(0.8 * len(full_dataset))

val_size = len(full_dataset) - train_size

trainset, valset = torch.utils.data.random_split(full_dataset, [train_size, val_size])


trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,

                                           shuffle=True, num_workers=2)

valloader = torch.utils.data.DataLoader(valset, batch_size=4,

                                           shuffle=False, num_workers=2)

classes = full_dataset.classes

我将图像转换为灰度,因为它们无论如何都是灰色的。我将图像裁剪为 32,因为图像具有不同的尺寸,并且我认为它们在通过神经网络时必须具有相同的尺寸。到目前为止一切正常。


训练神经网络


for epoch in range(2):  # loop over the dataset multiple times


    running_loss = 0.0

    for i, data in enumerate(trainloader, 0):

        # get the inputs

        inputs, labels = data


        # zero the parameter gradients

        optimizer.zero_grad()


        # forward + backward + optimize

        outputs = net(inputs)

        loss = criterion(outputs, labels)

        loss.backward()

        optimizer.step()



我的代码是此 Pytorch 教程中提供的代码的变体。有人能告诉我我做错了什么吗?


慕桂英3389331
浏览 152回答 1
1回答
随时随地看视频慕课网APP

相关分类

Python
我要回答