Pytorch 使用知识转移保存和加载 VGG16

我使用以下语句保存了一个带有知识转移的 VGG16:


torch.save(model.state_dict(), 'checkpoint.pth')

并使用以下语句重新加载:


state_dict = torch.load('checkpoint.pth') model.load_state_dict(state_dict)


只要我重新加载 VGG16 模型并使用以下代码为其提供与以前相同的设置,就可以工作:


model = models.vgg16(pretrained=True)

model.cuda()

for param in model.parameters(): param.requires_grad = False


class Network(nn.Module):

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):


#             input_size: integer, size of the input

#             output_size: integer, size of the output layer

#             hidden_layers: list of integers, the sizes of the hidden layers

#             drop_p: float between 0 and 1, dropout probability


        super().__init__()

        # Add the first layer, input to a hidden layer

        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])


        # Add a variable number of more hidden layers

        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])

        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])

        self.output = nn.Linear(hidden_layers[-1], output_size)

        self.dropout = nn.Dropout(p=drop_p)


    def forward(self, x):

        ''' Forward pass through the network, returns the output logits '''


        # Forward through each layer in `hidden_layers`, with ReLU activation and dropout

        for linear in self.hidden_layers:

            x = F.relu(linear(x))

            x = self.dropout(x)


        x = self.output(x)

        return F.log_softmax(x, dim=1)


classifier = Network(25088, 102, [4096], drop_p=0.5)

model.classifier = classifier

如何避免这种情况?如何重新加载模型而不必重新加载 VGG16 并重新定义分类器?


拉莫斯之舞
浏览 275回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python