项目地址
目的
根据 Feedback-Network (CVPR 2017, Zamir et al.) 论文提出的反馈网络结构,对CIFAR100或类似数据集进行分类。当前实现了CIFAR100数据集上的训练和测试,基本达到论文效果。
结果
Feedback-Network48,CIFAR100验证精度
val accuracy
Requirements
Pytorch = 0.3.1
python = 2.7
numpy >= 1.14.2
步骤
使用Pytorch为工具,实现CIFAR100数据集分类。
数据准备:
训练:将CIFAR100数据集放在
./data/
路径下。CIFAR100数据集下载测试:将训练好的模型放在
./models/
路径下。 baseline模型下载
训练步骤:
python FeedbackNet_train.py
每10个epoch后,训练好的模型将以
.pth
文件的形式保存在./models/
文件夹下。
运行
classifier_train.py
即可。
验证步骤:
python FeedbackNet_test.py
修改
classifier_test.py
文件相关参数,其中ckpt
表示模型加载位置,默认采用CIFAR100数据集中的test数据。然后运行
classifier_test.py
即可。在控制台输出验证结果。
方法
FeedbackNet:
以ConvLSTM为基础,实现网络结构。详细解读待更新。FeedbackNet
ConvLSTM with skip connections
训练代码流程
Hyper-params: 设置数据加载路径、模型保存路径、初始学习率等参数。
Training parameters: 用于定义模型训练中的相关参数,例如最大迭代次数、优化器、损失函数、是否使用GPU等、模型保存频率等
load data: 定义了用于读取数据,在其中实现了数据、标签读取及预处理过程。预处理过程在
__getitem__
中。models: 定义的FeedbackNet类,并实例化
optimizer、criterion、lr_scheduler: 定义优化器为SGD优化器,损失函数为CrossEntropyLoss,学习率调整策略采用ReduceLROnPlateau。
trainer: 定义了用于模型训练和验证的类Trainer,trainer为Trainer的实例化。在Trainer的构造函数中根据步骤二中的参数设定,对训练过程中的参数进行设置,包括训练数据、测试数据、模型、是否使用GPU等。
Trainer中定义了训练和测试函数,分别为train()
和_val_one_epoch()
。train()
函数中,根据设定的最大循环次数进行训练,每次循环调用_train_one_epoch()
函数进行单步训练。
测试代码流程
Test parameters: 用于定义模型测试中的相关参数
models: 定义的FeedbackNet类,并实例化
tester: 对测试类Tester实例化,Tester中主要进行模型加载函数与预测函数。
_load_ckpt()
函数加载模型;test()
函数进行预测,其中定义了对单张图片进行预处理的过程,并输出预测结果。
作者:Meng_Blog
链接:https://www.jianshu.com/p/939448669206