在卷积神经网络中设置层的维度

假设我有 4 个批次的 3x100x100 图像作为输入,并且我正在尝试使用 pytorch 制作我的第一个卷积神经网络。我真的不确定我的卷积神经网络是否正确,因为当我通过以下安排训练我的输入时,我遇到了错误:


Expected input batch_size (1) to match target batch_size (4).


以下是我的转发nnet:


那么如果我要通过它:


nn.Conv2d(3, 6, 5)

我会得到 6 层地图,每层都有尺寸(100-5+1)。


那么如果我要通过它:


nn.MaxPool2d(2, 2)

我会得到 6 层地图,每层都有尺寸 (96/2)


然后,如果我要通过它:


nn.Conv2d(6, 16, 5)

我会得到 16 层地图,每层都有尺寸 (48-5+1)


那么如果我要通过它:


self.fc1 = nn.Linear(44*44*16, 120)

我会得到 120 个神经元


那么如果我要通过它:


self.fc2 = nn.Linear(120, 84)

我会得到 84 个神经元


那么如果我要通过它:


self.fc3 = nn.Linear(84, 3)

我会得到 3 个输出,这将是完美的,因为我有 3 类标签。但正如我之前所说,这会导致一个非常令人惊讶的错误,因为这对我来说很有意义。


完整的神经网络代码:


import torch.nn as nn

import torch.nn.functional as F



class Net(nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)

        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(6, 16, 5)

        self.fc1 = nn.Linear(44*44*16, 120)

        self.fc2 = nn.Linear(120, 84)

        self.fc3 = nn.Linear(84, 3)


    def forward(self, x):

        x = self.pool(F.relu(self.conv1(x)))

        x = self.pool(F.relu(self.conv2(x)))

        x = x.view(-1, 16 *44*44)

        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)

        return x



net = Net()

net.to(device)


桃花长相依
浏览 307回答 1
1回答

慕村225694

你的理解是正确的,非常详细。但是,您使用了两个池化层(请参阅下面的相关代码)。所以第二步之后的输出将是16个44/2=22维度的地图。x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))要解决此问题,要么不池化,要么将全连接层的维度更改为22*22*16。要通过不池化来修复,请修改您的转发功能,如下所示。def forward(self, x):    x = self.pool(F.relu(self.conv1(x)))    x = F.relu(self.conv2(x))    x = x.view(-1, 16 *44*44)    x = F.relu(self.fc1(x))    x = F.relu(self.fc2(x))    x = self.fc3(x)    return x要通过更改全连接层的维度来修复,请更改网络的声明如下。def __init__(self):    super(Net, self).__init__()    self.conv1 = nn.Conv2d(3, 6, 5)    self.pool = nn.MaxPool2d(2, 2)    self.conv2 = nn.Conv2d(6, 16, 5)    self.fc1 = nn.Linear(22*22*16, 120)    self.fc2 = nn.Linear(120, 84)    self.fc3 = nn.Linear(84, 10)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python