继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

快速入门深度学习(2)迁移学习

慕容森
关注TA
已关注
手记 358
粉丝 183
获赞 649

咱们继续入门课程系列,这次是关于迁移学习(Transfer Learning)的故事。

    有点神经网络基础的基础就会看得明白,上次的例子就是拿已经训练好的网络去用一下嘛,根本无关什么“深度学习”,完全可以把AlexNet看做一个分类器,把待分类的数据丢进去就可以了,这里还有一个GoogleNet的例子:http://ww2.mathworks.cn/help/nnet/examples/classify-image-using-googlenet.html

    这次咱们要“学习”一把了,针对特定的任务构造自己的分类器了。这次咱们仍然使用AlexNet的网络结构(谁让它经典呢),训练这个网络让它为咱们服务。

    在正式Coding之前,首先了解下什么是迁移学习。所谓的迁移学习就是指在深度学习中,把一个学习好的深度网络,稍加改造变成自己特有网络的意思,至于这样做的道理,咱们这里先不深入探讨,只要先记住迁移学习有个很大的好处,就是网络收敛速度快。

实验准备

Matlab2017b或者更新的版本,AlexNet。

数据准备:为了实验的一致性,使用Matlab计算机视觉工具箱自带的数据。

开始编程

载入数据

unzip('MerchData.zip');

imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');

[imdsTrain,idmsValidation] =splitEachLabel(imds,0.7,'randomized');

unzip函数的意思是解压压缩文件。执行这一句之后可以看到在当前目录下多了一个文件夹:


这个文件夹里面就是本次实验所使用的数据。为了更方便地组织该数据,我们使用imageDatastore函数来构造一个数据结构,用以管理数据。执行上面一句之后得到来一个imageDatastore数据结构,我们进入当前的工作空间对其进行观察。

可以看到待使用的数据,被一个数据结构进行了组织,并且使用文件夹的名称作为了类标签。我们随机选择16个图像用 的方式进行显示。

numImages= numel(imds.Labels);%统计总数

idx =randperm(numImages,16);  %随机选择

figure

for i = 1:16

    subplot(4,4,i)

    I = readimage(imds,idx(i));

    imshow(I)

end

可以看到

我们接下来把图像分为测试集(30%)和训练集(70%):

[imdsTrain,idmsValidation]= splitEachLabel(imds,0.7,'randomized');

数据准备完毕了。

加载AlexNet网络

由于我们这一章讲的是迁移学习,所以接下来需要加载已经训练好的alexnet网络。关于如何加载请参看前一章。

net =alexnet;

修改网络

由于咱们这次只需要识别5个类,所以需要对AlexNet网络进行修改以适应当前的问题。我们这次主要对其进行如下修改:修改全连接层的输出数量,从原来的1000变为5,其余保持不变。首先提取出前面的层数,然后使用fullyConnectedLayer构造全连接层,最后完成整个网络的构建。

layersTransfer= net.Layers(1:end-3);

 

layers =[

    layersTransfer

    fullyConnectedLayer(5,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)

    softmaxLayer

    classificationLayer];

最后的结果layer就是我们需要的网络结构,此时网络还未经训练。

训练网络

训练网络在Matlab中是一件非常简单的事情,我们只需要配置好训练参数就好了:

options = trainingOptions('sgdm',...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',idmsValidation, ...
    'ValidationFrequency',3, ...
    'ValidationPatience',Inf, ...
    'Verbose',false, ...
    'Plots','training-progress');

关于训练的参数,咱们以后再详细介绍,这里需要了解的一点就是,由于神经网络参数众多,而且是一个典型的非凸优化问题,所以,训练的参数选择相当重要。

netTransfer = trainNetwork(imdsTrain,layers,options);

运行完上面一句就可以得到netTransfer作为迁移网络。

验证网络

我们使用验证集去测试神经网络的有效性:

YPred = classify(netTransfer,idmsValidation);
accuracy = mean(YPred == idmsValidation.Labels)

结果表明我们的训练出来的神经网络具有良好的泛化性。

总结

从上面的编程过程中,可以发现Matlab神经网络工具箱已经帮助我们做好了很多工作,我们只需要去设计网络即可,然后训练即可,把广大程序员从无边无际的codeing中解放出来。

原文出处

打开App,阅读手记
1人推荐
发表评论
随时随地看视频慕课网APP