继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

机器学习:如何找到最优学习率

大话西游666
关注TA
已关注
手记 209
粉丝 140
获赞 620

原文见点击打开链接


学习率的重要性


目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

[Math Processing Error]w:=w−α∂∂wloss(w)

这里[Math Processing Error]α就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。



学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。

https://img.mukewang.com/5b125c6f000148cb03890266.jpg



从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。


实现


上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重。

下面就是部分代码,这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码如下:

class ScheduledOptim(object):
    '''A wrapper class for learning rate scheduling'''

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.lr = self.optimizer.param_groups[0]['lr']
        self.current_steps = 0

    def step(self):
        "Step by the inner optimizer"
        self.current_steps += 1
        self.optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self.optimizer.zero_grad()

    def set_learning_rate(self, lr):
        self.lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    @property
    def learning_rate(self):
        return self.lr


def find_lr():
    pass


def train_model():
    pass


整体代码如下

  1. criterion = torch.nn.CrossEntropyLoss()  

  2. net = model_zoo.resnet50(pretrained=True)  

  3. net.fc = nn.Linear(2048, 120)  

  4. with torch.cuda.device(0):  

  5.     net = net.cuda()  

  6. basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)  

  7. optimizer = ScheduledOptim(basic_optim)  

  8. lr_mult = (1 / 1e-5) ** (1 / 100)  

  9. lr = []  

  10. losses = []  

  11. best_loss = 1e9  

  12. for data, label in train_data:  

  13.     with torch.cuda.device(0):  

  14.         data = Variable(data.cuda())  

  15.         label = Variable(label.cuda())  

  16.     # forward  

  17.     out = net(data)  

  18.     loss = criterion(out, label)  

  19.     # backward  

  20.     optimizer.zero_grad()  

  21.     loss.backward()  

  22.     optimizer.step()  

  23.     lr.append(optimizer.learning_rate)  

  24.     losses.append(loss.data[0])  

  25.     optimizer.set_learning_rate(optimizer.learning_rate * lr_mult)  

  26.     if loss.data[0] < best_loss:  

  27.         best_loss = loss.data[0]  

  28.     if loss.data[0] > 4 * best_loss or optimizer.learning_rate > 1.:  

  29.         break  

  30. plt.figure()  

  31. plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))  

  32. plt.xlabel('learning rate')  

  33. plt.ylabel('loss')  

  34. plt.plot(np.log(lr), losses)  

  35. plt.show()  

  36. plt.figure()  

  37. plt.xlabel('num iterations')  

  38. plt.ylabel('learning rate')  

  39. plt.plot(lr)  

原文出处

打开App,阅读手记
1人推荐
发表评论
随时随地看视频慕课网APP