本篇更多的是在代码实战方向,不会涉及太多的理论。本文主要针对TensorFlow和卷积神经网络有一定基础的同学,并对图像处理有一定的了解。
阅读本文你大概需要以下知识:
1.TensorFlow基础
2.TensorFlow实现卷积神经网络的前向传播过程
3.TFRecord数据格式
4.Dataset的使用
5.Slim的使用
好了废话不多说,下面开始。
一.数据准备
首先我们需要有一个让我们训练的数据集,这里谷歌已经帮我们做好了。这里要把数据集下载下来,打开命令行,执行如下命令:
wget http://download.tensorflow.org/example_image/flower_photo.tgz//解压tar xzf flower_photos.tgz
这里需要注意的是,文件最好是下载到你的工程目录下方便你的读取。什么?你还不会搭建TensorFlow程序?请移步https://www.tensorflow.org/install/
选择自己的操作系统,在这里我的是macOS。我使用的是Virtualenv来搭建TensorFlow运行环境。
数据集下载并解压后,我们可以看到大概是这个样子

每一个文件夹里都是一个种类的花的图片,这里总共有五种花。
好了,数据有了?接下来该怎么办呢?当然是把数据进行预处理拉,你不会觉得我们的TensorFlow可以直接识别这些图片进行训练吧,hhhhhh。
二.数据预处理
接下来我们在目录下新建pre_data.python文件。TensorFlow对图片做处理一般是生成TFRecord文件。什么是TFRecord?后面我们会讲到。
首先我们要引入我们需要的库。
# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)import glob#os.path生成路径方便glob获取import os.path#这里主要用到随机数import numpy as np#引入tensorflow框架import tensorflow as tf#引入gflie对图片做处理from tensorflow.python.platform import gfile
相关库在我们这个程序中的功能都作了简单介绍,下面用到的时候我们会更加详细的说明。
大家都知道我们的数据集一般分训练,测试和验证数据集。观察上面的数据集,谷歌只是给出了每一种花的图片,并没有给去哪些我训练,哪些是测试,哪些是验证数据集。所以在这里我们要进行划分。
#输入图片地址INPUT_DATA = '../../flower_photos'#训练数据集OUTPUT_FILE = './path/to/output.tfrecords'#测试数据集OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'#验证数据集OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'#测试数据和验证数据的比例VALIDATION_PERCENTAGE = 10 TEST_PERCENTAGE = 10
关于VALIDATION_PERCENTAGE和TEST_PERCENTAGE这两个常量,我们在后面的例子会给出。
下面我们就来定义处理数据的方法:
def create_image_lists(sess,testing_percentage,validation_percentage):
#拿到INPUT_DATA文件夹下的所有目录(包括root)
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] #如果是root_dir不需要做处理
is_root_dir = True
#定义图片对应的标签,从0-4分别代表不同的花
current_label = 0
#写入TFRecord的数据需要首先定义writer
#这里定义三个writer分别存储训练,测试和验证数据
writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE) #循环目录
for sub_dir in sub_dirs: if is_root_dir: #跳过根目录
is_root_dir = False
continue
#定义空数组来装图片路径
file_list = [] #生成查找路径
dir_name = os.path.basename(sub_dir)
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg") # extend合并两个数组
# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)
# 比如:glob.glob(r’c:*.txt’) 这里就是获得C盘下的所有txt文件
file_list.extend(glob.glob(file_glob)) #路径下没有文件就跳过,不继续操作
if not file_list: continue
#这里我定义index来打印当前进度
index = 0
#file_list此时是图片路径列表
for file_name in file_list: #使用gfile从路径中读取图片
image_raw_data = gfile.FastGFile(file_name, 'rb').read() #对图像解码,解码结果为一个张量
image = tf.image.decode_jpeg(image_raw_data) #对图像矩阵进行归一化处理
#因为为了将图片数据能够保存到 TFRecord 结构体中
#所以需要将其图片矩阵转换成 string
#所以为了在使用时能够转换回来
#这里确定下数据格式为 tf.float32
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 将图片转化成299*299方便模型处理
image = tf.image.resize_images(image, [299, 299]) #为了拿到图片的真实数据这里我们要运行一个session op
image_value = sess.run(image)
pixels = image_value.shape[1] #存储在TFrecord里面的不能是array的形式
#所以我们需要利用tostring()将上面的矩阵
#转化成字符串
#再通过tf.train.BytesList转化成可以存储的形式
image_raw = image_value.tostring() #存到features
#随机划分测试集和训练集
#这里存入TFRecord三个数据,图像的pixels像素
#图像原张量,这里我们需要转成string
#以及当前图像对应的标签
example = tf.train.Example(features=tf.train.Features(feature={ 'pixels': _int64_feature(pixels), 'label': _int64_feature(current_label), 'image_raw': _bytes_feature(image_raw)
}))
chance = np.random.randint(100) #随机划分数据集
if chance < validation_percentage:
writer_validation.write(example.SerializeToString()) elif chance < (testing_percentage+validation_percentage):
writer_test.write(example.SerializeToString()) else:
writer.write(example.SerializeToString()) # print('example',index)
index = index + 1
#每一个文件夹下的所有图片都是一个类别
#所以这里每遍历完一个文件夹,标签就增加1
current_label += 1
writer.close()
writer_validation.close()
writer_test.close()运行上述程序需要一定时间,我的电脑比较烂,大概跑了三十分钟左右。这时候在你的./path/to目录下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三个文件,分别存放了训练,测试和验证数据集。上述代码将所有图片划分成训练、验证和测试数据集。并且把图片从原始的jpg格式转换成inception-v3模型需要的299 * 299 * 3的数字矩阵。在数据处理完毕之后,通过以下命令可以下载谷歌提供好的Inception_v3模型。
wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz//解压之后可以得到训练好的模型文件inception_v3.ckpttar xzf inception_v3_2016_08
二.训练
当新的数据集和已经训练好的模型都准备好之后,我们来写代码在谷歌inception_v3的基础上训练新数据集。
首先同样我们导入相关的库并且定义相关常量。在这里我们通过slim工具来直接加载模型,而不用自己再定义前向传播过程。
import numpy as npimport tensorflow as tfimport tensorflow.contrib.slim as slim# 加载通过TensorFlow-Silm定义好的 inception_v3模型import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3# 输入数据文件INPUT_DATA = './path/to/output.tfrecords'# 验证数据集VALIDATION_DATA = './path/to/output_validation.tfrecords'# 保存训练好的模型的路径ls = './path/to/save_model'# 谷歌提供的训练好的模型文件地址CKPT_FILE = './path/to/inception_v3.ckpt'TRAIN_FILE = './path/to/save_model'# 定义训练中使用的参数LEARNING_RATE = 0.01#组合batch的大小BATCH = 32#用于one_hot函数输出概率分布N_CLASSES = 5#打乱顺序,并设置出队和入队中元素最少的个数,这里是10000个shuffle_buffer = 10000# 不需要从谷歌模型中加载的参数,这里就是最后的全连接层。因为输出类别不一样,所以最后全连接层的参数也不一样CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'# 需要训练的网络层参数 这里就是最后的全连接层TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
接下来我们定义几个辅助方法。首先因为我们的数据存在TFRecord里,需要定义方法从TFRecord解析数据。
def parse(record):
features = tf.parse_single_example(
record,
features={ 'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64), 'pixels': tf.FixedLenFeature([], tf.int64)
}
) #decode_raw用于解析TFRecord里面的字符串
decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
label = features['label'] #要注意这里的decoded_image并不能直接进行reshape操作
#之前我们在存储的时候,把图片进行了tostring()操作
#这会导致图片的长度在原来基础上*8
#后面我们要用到numpy的fromstring来处理
return decoded_image, label接下来定义两个方法。因为我们已经下载了谷歌训练好的inception_v3模型的参数,下面我们需要定义两个方法从里面加载参数。
#直接从inception_v3.ckpt中读取的参数def get_tuned_variables():
#strip删除头尾字符,默认为空格
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
variables_to_restore = [] #这里给出了所有slim模型下的参数
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions: if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var) return variables_to_restore#需要重新训练的参数def get_trainable_variables():
#strip删除头尾字符,默认为空格
scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
variables_to_train = [] # 枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。
for scope in scopes: #从TRAINABLE_VARIABLES集合中获取名为scope的变量
#也就是我们需要重新训练的参数
variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope)
variables_to_train.extend(variables) return variables_to_train这里我们就写完了所需要的工具函数,接下来我们定义主函数。主函数主要完成数据读取,模型定义,通过模型得出前向传播结果,通过损失函数计算损失,最后把损失交给优化器做处理。首先我们先来完成数据读取的代码,这里我们使用的是TensorFlow高层API Dataset。不清楚的可以去看一下Dataset的用法。
这里我们在训练的同时也对模型做了验证。所以我们需要加载训练和验证数据
#读取测试数据 #利用TFRecordDataset读取TFRecord文件 dataset = tf.data.TFRecordDataset([INPUT_DATA]) #解析TFRecord dataset = dataset.map(parse) #把数据打乱顺序并组装成batch dataset = dataset.shuffle(shuffle_buffer).batch(BATCH) #定义数据重复的次数 NUM_EPOCHS = 10 dataset = dataset.repeat(NUM_EPOCHS) #定义迭代器来获取处理后的数据 iterator = dataset.make_one_shot_iterator() #迭代器开始迭代 img, label = iterator.get_next() #读取验证数据(同上) valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA]) valida_dataset = valida_dataset.map(parse) valida_dataset = valida_dataset.batch(BATCH) valida_iterator = valida_dataset.make_one_shot_iterator() valida_img,valida_label = valida_iterator.get_next() #定义inception-v3的输入,images为输入图片,label为每一张图片对应的标签 #再解释下每一个维度 None为batch的大小,299为图片大小,3为通道 images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images') labels = tf.placeholder(tf.int64,[None],name='labels')
要注意上述定义的只是tensorflow的张量,保存的只是计算过程并没有具体的数据。只有运行session之后才会拿到具体的数据。
下面我们来通过slim加载inception-v3模型
#定义inception-v3模型结构 inception_v3.ckpt里只有参数的取值
with slim.arg_scope(inception_v3.inception_v3_arg_scope()): #logits inception_v3前向传播得到的结果
logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES) #获取需要训练的变量
trainable_variables = get_trainable_variables() #这里用交叉熵作为损失函数,注意一下tf.losses.softmax_cross_entropy的参数
# tf.losses.softmax_cross_entropy(
# onehot_labels, # 注意此处参数名就叫 onehot_labels
# logits,
# weights=1.0,
# label_smoothing=0,
# scope=None,
# loss_collection=tf.GraphKeys.LOSSES,
# reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
# )
#这里要把labels转成one_hot类型,logits就是神经网络的输出
tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0) #把计算的损失交给优化器处理
train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss()) #计算正确率。
with tf.name_scope('evaluation'):
correct_prediction = tf.equal(tf.argmax(logits,1),labels)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #定义加载模型的函数
load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True) #定义保存新的训练好的模型的函数
saver = tf.train.Saver() with tf.Session() as sess: #初始化所有变量
init = tf.global_variables_initializer()
sess.run(init)
print('Loading tuned variables from %s'%CKPT_FILE) #加载谷歌已经训练好的模型
load_fn(sess)
step = 0; #在这里我们用一个while来循环训练,直到dataset里没有数据就结束循环
while True: try: if step % 30 == 0 or step + 1 == STEPS: #每30轮输出一次正确率
if step != 0: #每30轮保存一次当前模型的参数,以便中途训练中断可以继续
saver.save(sess,TRAIN_FILE,global_step=step) #运行session拿到真实图片的数据
valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label]) #上面有提到TFRecord里图片数据被转成了string,在这里转回来
valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32) #把图片张量拉成新的维度
valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3]) #用session运行上述操作,得到处理后的图片张量
valida_img_batch = sess.run(valida_img_batch) #把图片张量传到feed_dict算出正确率并显示
validation_accuracy = sess.run(evaluation_step,feed_dict={
images:valida_img_batch,
labels:valida_label_batch
})
print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0)) #下面是对训练数据的操作,同上
img_batch,label_batch = sess.run([img,label])
img_batch = np.fromstring(img_batch, dtype=np.float32)
img_batch = tf.reshape(img_batch, [32,299, 299, 3])
img_batch = sess.run(img_batch)
sess.run(train_step,feed_dict={
images:img_batch,
labels:label_batch
}) #step仅仅用于记录
step = step + 1
except tf.errors.OutOfRangeError: break运行上述程序开始训练。在这里我暂时是使用cpu进行训练,训练过程大约3小时,可以得到类型下面的结果。
step 0:Validation accuracy = 12.5% step 30:Validation accuracy = 22.2% step 60:Validation accuracy = 63.2% step 90:Validation accuracy = 79.8% step 120:Validation accuracy = 86.4% step 150:Validation accuracy = 88.5% .....
作者:sidiWang
链接:https://www.jianshu.com/p/fc77879d3591
随时随地看视频