继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

深度神经网络优化策略之——残差学习

跃然一笑
关注TA
已关注
手记 303
粉丝 40
获赞 164

问题起源

深度学习普遍认为发端于2006年,根据Bengio的定义,深层网络由多层自适应非线性单元组成——即多层非线性模块的级联,所有层次上都包含可训练的参数,在工程实际操作中,深层神经网络通常是五层及以上,包含数百万个可学习的自由参数的庞然大物。理论上,网络模型无论深浅与否,都能通过函数逼近数据的内在关系和本质特征,但在解决真实世界的复杂问题时,需要指数增长的计算单元,浅层网络往往出现函数表达能力不足,而深层网络则可能仅仅需要较少的计算单元。
  不过网络并不是像理论上那样越深越好,除了显而易见的因为层数过多而导致浪费性质的占用显存和“吃”计算力的问题,还会出现以下三种问题。

  • 过拟合   (over fit)

  • 梯度弥散 (vanishing gradient problem)

  • 网络退化 (degenerate)

其中,问题一、二并不是本文所讲的残差学习主要要解决的问题,所以就不多赘述,只讲述网络退化的问题。其现象如下图所示,是随着网络层数的增多,整体模型的表达能力增强,但是训练精度反而变差,并且因为训练精度本身也下降的缘故,故而可以排除是过拟合的原因,而确定是网络退化。

When deeper networks are able to start converging, a degradation problem has been exposed: with the network depth increasing, accuracy gets saturated which might be unsurprising and then degrades rapidly. Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error,as reported in and thoroughly verified by our experiments.


https://img.mukewang.com/5d5743760001bd4a07270287.png

但是,很可惜的是,业界对于网络退化的原因及其标准情况依然没有定论,甚至说出现了随着网络变深而效果变差的问题的时候,也有可能无法分辨出是梯度弥散还是网络退化的问题。读者如果有兴趣,可以自行去寻找网络退化方面的研究论文,各家的观点虽然都不尽相同,但我们还是可以发现不少有用的信息。

残差学习

而对于上述问题,Kaiming He大神提出了一种简洁而不失优雅的残差学习的方法。多的不谈,我们直接甩出模型结构来讲解残差学习的思想。

https://img3.mukewang.com/5d5743a300011f5604810829.jpg

首先,只看图的左半边,也就是橘红色的部分。左侧与普通网络连接方式的区别一目了然——在顺次直连而下的基础上加入了每隔两层的跨接桥(其实官方的叫法并非如此,然而这么叫它显得更加直观)。不过纯凭看图的感觉毕竟流于表面,用数学说话才是严谨的态度。

对于一个神经网络而言,我们需要用反向传播来更新参数,就像这样:

https://img2.mukewang.com/5d5743d00001ef4902430140.png

此时,第二个式子所得的结果就是我们常说的梯度。

而当如下图网络越来越深的时候:

https://img2.mukewang.com/5d5743d7000133f703230248.png

......


这时候再通过算偏导求梯度,就会是这样:

https://img1.mukewang.com/5d57441c0001f78f05290091.png

其实数列的每一项都很小,再依此相乘就会越来越小,最后趋近于0,举个简单的例子就是0.9虽然很接近于1,但当有n个0.9相乘时(n趋近于无限大),最后的结果就会无限趋近于0。

而当有了“跨接桥”之后,我们再算偏导的时候就会变成这样:

https://img.mukewang.com/5d5743e20001a9c604820099.png

说白了就是1.01的n次方依然大于1。

最后,我们可以发现对于相同的数据集来讲,残差网络比同等深度的其他网络表现出了更好的性能。

https://img.mukewang.com/5d5744400001ce8507490485.png

https://img.mukewang.com/5d5744430001a78107560577.png

不过,这是大神的测试结果,没有什么说服力,而我在自己的项目里做了一组关于有无残差学习的对比,下面是数据图(项目是和图像增强有关,所以用PSNR作为评判标准):

https://img4.mukewang.com/5d574448000114a907550669.png

最后可见,Loss的下降趋势,残差学习的方法明显更加平稳,而最后结果Loss和PSNR虽然差距目测不大,但最后的图片视觉效果却千差万别。



作者:Eaton_Lee
链接:https://www.jianshu.com/p/c33acc52b4bc


打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP