继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

自己动手制作人工神经网络0x4:实际训练

潇潇雨雨
关注TA
已关注
手记 341
粉丝 25
获赞 130

接下来的内容,是关于如何实际应用之前编写的ANN,来完成手写数字识别的任务。

准备

首先,需要下载数据集,以用于训练和测试。这里使用缩小版的mnist数据集。训练集有100条数据,测试集有10条数据。大家可以去这里下载。
https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_test_10.csv
https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_train_100.csv

Coding

导入需要的库。

import matplotlib.pyplot as plt
%matplotlib inlineimport numpy as npimport scipy.specialfrom NN import NN

导入数据。

train_file = open("./mnist_train_100.csv", 'r')
train_list = train_file.readlines()
train_file.close()

test_file = open("./mnist_test_10.csv", 'r')
test_list = test_file.readlines()
test_file.close()

初始化模型。

input_nodes = 784
hidden_nodes = 100
output_nodes = 10

learning_rate = 0.3

nn = NN(input_nodes, hidden_nodes, output_nodes, learning_rate)

开始训练。epoch为迭代次数,因为数据集比较小,每轮迭代都会用上整个训练集。这里就只迭代一次好了。
数据集每一行为一条数据,第一个值是标签,也就是这条数据所代表的数字。接下来784(28*28)个值则是每一个像素点,范围是0~255。这里等比例缩小每个像素的值,把范围缩到0~0.99,再加上0.01。最终范围是0.01到1,避免了值为0的“死值”。
标签的值也不能太极端,我们对对应数字的节点的期望值为0.99,而其他节点的期望值为0.01。

epoch = 1for i in range(epoch):    for record in train_list:
        values = record.split(',')
        inputs = (np.asfarray(values[1:]) / 255.0 * 0.99) + 0.01
        labels = np.zeros(output_nodes) + 0.01
        labels[int(values[0])] = 0.99

        nn.train(inputs, labels)

这里大家可以尝试使用完整的mnist数据集来训练。也可以尝试更改迭代次数。



作者:御史神风
链接:https://www.jianshu.com/p/9842c03dc72b

打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP