import torch
1. Requires_grad
但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记。
在PyTorch中,通用的数据结构tensor包含一个attributerequires_grad
,它被用于说明当前量是否需要在计算中保留对应的梯度信息,以上文所述的线性回归为例,容易知道参数www为需要训练的对象,为了得到最合适的参数值,我们需要设置一个相关的损失函数,根据梯度回传的思路进行训练。
官方文档中的说明如下
If there’s a single input to an operation that requires gradient, its output will also require gradient.
只要某一个输入需要相关梯度值,则输出也需要保存相关梯度信息,这样就保证了这个输入的梯度回传。
而反之,若所有的输入都不需要保存梯度,那么输出的requires_grad
会自动设置为False。既然没有了相关的梯度值,自然进行反向传播时会将这部分子图从计算中剔除。
Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.
对于那些要求梯度的tensor,PyTorch会存储他们相关梯度信息和产生他们的操作,这产生额外内存消耗,为了优化内存使用,默认产生的tensor是不需要梯度的。
而我们在使用神经网络时,这些全连接层卷积层等结构的参数都是默认需要梯度的。
a = torch.tensor([1., 2., 3.])
print('a:', a.requires_grad)
b = torch.tensor([1., 4., 2.], requires_grad = True)
print('b:', b.requires_grad)
print('sum of a and b:', (a+b).requires_grad)
a: False
b: True
sum of a and b: True
2. Computation Graph
从PyTorch的设计原理上来说,在每次进行前向计算得到pred时,会产生一个用于梯度回传的计算图,这张图储存了进行back propagation需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放
这张计算图保存了计算的相关历史和提取计算所需的所有信息,以output作为root节点,以input和所有的参数为leaf节点,
we only retain the grad of the leaf node with requires_grad =True
在完成了前向计算的同时,PyTorch也获得了一张由计算梯度所需要的函数所组成的图
而从数据集中获得的input其requires_grad
为False,故我们只会保存参数的梯度,进一步据此进行参数优化
在PyTorch中,multi-task任务一个标准的train from scratch流程为
for idx, data in enumerate(train_loader):
xs, ys = data
optmizer.zero_grad()
# 计算d(l1)/d(x)
pred1 = model1(xs) #生成graph1
loss = loss_fn1(pred1, ys)
loss.backward() #释放graph1
# 计算d(l2)/d(x)
pred2 = model2(xs)#生成graph2
loss2 = loss_fn2(pred2, ys)
loss.backward() #释放graph2
# 使用d(l1)/d(x)+d(l2)/d(x)进行优化
optmizer.step()
Computation Graph本质上是一个operation的图,所有的节点都是一个operation,而进行相应计算的参数则以叶节点的形式进行输入
借助torchviz库以下面的模型作为示例
import torch.nn.functional as F
import torch.nn as nn
class Conv_Classifier(nn.Module):
def __init__(self):
super(Conv_Classifier, self).__init__()
self.conv1 = nn.Conv2d(1, 5, 5)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(5, 16, 5)
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x):
x = F.relu(self.pool1((self.conv1(x))))
x = F.relu(self.pool2((self.conv2(x))))
x = F.dropout2d(x, training=self.training)
x = x.view(-1, 256)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
Mnist_Classifier = Conv_Classifier()
from torchviz import make_dot
input_sample = torch.rand((1, 1, 28, 28))
make_dot(Mnist_Classifier(input_sample), params=dict(Mnist_Classifier.named_parameters()))
其对应的计算梯度所需的图(计算图)为
可以看到,所有的叶子节点对应的操作都被记录,以便之后的梯度回传。