本篇更多的是在代码实战方向,不会涉及太多的理论。本文主要针对TensorFlow和卷积神经网络有一定基础的同学,并对图像处理有一定的了解。
阅读本文你大概需要以下知识:
1.TensorFlow基础
2.TensorFlow实现卷积神经网络的前向传播过程
3.TFRecord数据格式
4.Dataset的使用
5.Slim的使用
好了废话不多说,下面开始。
一.数据准备
首先我们需要有一个让我们训练的数据集,这里谷歌已经帮我们做好了。这里要把数据集下载下来,打开命令行,执行如下命令:
wget http://download.tensorflow.org/example_image/flower_photo.tgz//解压tar xzf flower_photos.tgz
这里需要注意的是,文件最好是下载到你的工程目录下方便你的读取。什么?你还不会搭建TensorFlow程序?请移步https://www.tensorflow.org/install/
选择自己的操作系统,在这里我的是macOS。我使用的是Virtualenv来搭建TensorFlow运行环境。
数据集下载并解压后,我们可以看到大概是这个样子
每一个文件夹里都是一个种类的花的图片,这里总共有五种花。
好了,数据有了?接下来该怎么办呢?当然是把数据进行预处理拉,你不会觉得我们的TensorFlow可以直接识别这些图片进行训练吧,hhhhhh。
二.数据预处理
接下来我们在目录下新建pre_data.python文件。TensorFlow对图片做处理一般是生成TFRecord文件。什么是TFRecord?后面我们会讲到。
首先我们要引入我们需要的库。
# glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)import glob#os.path生成路径方便glob获取import os.path#这里主要用到随机数import numpy as np#引入tensorflow框架import tensorflow as tf#引入gflie对图片做处理from tensorflow.python.platform import gfile
相关库在我们这个程序中的功能都作了简单介绍,下面用到的时候我们会更加详细的说明。
大家都知道我们的数据集一般分训练,测试和验证数据集。观察上面的数据集,谷歌只是给出了每一种花的图片,并没有给去哪些我训练,哪些是测试,哪些是验证数据集。所以在这里我们要进行划分。
#输入图片地址INPUT_DATA = '../../flower_photos'#训练数据集OUTPUT_FILE = './path/to/output.tfrecords'#测试数据集OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'#验证数据集OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'#测试数据和验证数据的比例VALIDATION_PERCENTAGE = 10 TEST_PERCENTAGE = 10
关于VALIDATION_PERCENTAGE和TEST_PERCENTAGE这两个常量,我们在后面的例子会给出。
下面我们就来定义处理数据的方法:
def create_image_lists(sess,testing_percentage,validation_percentage): #拿到INPUT_DATA文件夹下的所有目录(包括root) sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] #如果是root_dir不需要做处理 is_root_dir = True #定义图片对应的标签,从0-4分别代表不同的花 current_label = 0 #写入TFRecord的数据需要首先定义writer #这里定义三个writer分别存储训练,测试和验证数据 writer = tf.python_io.TFRecordWriter(OUTPUT_FILE) writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE) writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE) #循环目录 for sub_dir in sub_dirs: if is_root_dir: #跳过根目录 is_root_dir = False continue #定义空数组来装图片路径 file_list = [] #生成查找路径 dir_name = os.path.basename(sub_dir) file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg") # extend合并两个数组 # glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list) # 比如:glob.glob(r’c:*.txt’) 这里就是获得C盘下的所有txt文件 file_list.extend(glob.glob(file_glob)) #路径下没有文件就跳过,不继续操作 if not file_list: continue #这里我定义index来打印当前进度 index = 0 #file_list此时是图片路径列表 for file_name in file_list: #使用gfile从路径中读取图片 image_raw_data = gfile.FastGFile(file_name, 'rb').read() #对图像解码,解码结果为一个张量 image = tf.image.decode_jpeg(image_raw_data) #对图像矩阵进行归一化处理 #因为为了将图片数据能够保存到 TFRecord 结构体中 #所以需要将其图片矩阵转换成 string #所以为了在使用时能够转换回来 #这里确定下数据格式为 tf.float32 if image.dtype != tf.float32: image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 将图片转化成299*299方便模型处理 image = tf.image.resize_images(image, [299, 299]) #为了拿到图片的真实数据这里我们要运行一个session op image_value = sess.run(image) pixels = image_value.shape[1] #存储在TFrecord里面的不能是array的形式 #所以我们需要利用tostring()将上面的矩阵 #转化成字符串 #再通过tf.train.BytesList转化成可以存储的形式 image_raw = image_value.tostring() #存到features #随机划分测试集和训练集 #这里存入TFRecord三个数据,图像的pixels像素 #图像原张量,这里我们需要转成string #以及当前图像对应的标签 example = tf.train.Example(features=tf.train.Features(feature={ 'pixels': _int64_feature(pixels), 'label': _int64_feature(current_label), 'image_raw': _bytes_feature(image_raw) })) chance = np.random.randint(100) #随机划分数据集 if chance < validation_percentage: writer_validation.write(example.SerializeToString()) elif chance < (testing_percentage+validation_percentage): writer_test.write(example.SerializeToString()) else: writer.write(example.SerializeToString()) # print('example',index) index = index + 1 #每一个文件夹下的所有图片都是一个类别 #所以这里每遍历完一个文件夹,标签就增加1 current_label += 1 writer.close() writer_validation.close() writer_test.close()
运行上述程序需要一定时间,我的电脑比较烂,大概跑了三十分钟左右。这时候在你的./path/to目录下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三个文件,分别存放了训练,测试和验证数据集。上述代码将所有图片划分成训练、验证和测试数据集。并且把图片从原始的jpg格式转换成inception-v3模型需要的299 * 299 * 3的数字矩阵。在数据处理完毕之后,通过以下命令可以下载谷歌提供好的Inception_v3模型。
wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz//解压之后可以得到训练好的模型文件inception_v3.ckpttar xzf inception_v3_2016_08
二.训练
当新的数据集和已经训练好的模型都准备好之后,我们来写代码在谷歌inception_v3的基础上训练新数据集。
首先同样我们导入相关的库并且定义相关常量。在这里我们通过slim工具来直接加载模型,而不用自己再定义前向传播过程。
import numpy as npimport tensorflow as tfimport tensorflow.contrib.slim as slim# 加载通过TensorFlow-Silm定义好的 inception_v3模型import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3# 输入数据文件INPUT_DATA = './path/to/output.tfrecords'# 验证数据集VALIDATION_DATA = './path/to/output_validation.tfrecords'# 保存训练好的模型的路径ls = './path/to/save_model'# 谷歌提供的训练好的模型文件地址CKPT_FILE = './path/to/inception_v3.ckpt'TRAIN_FILE = './path/to/save_model'# 定义训练中使用的参数LEARNING_RATE = 0.01#组合batch的大小BATCH = 32#用于one_hot函数输出概率分布N_CLASSES = 5#打乱顺序,并设置出队和入队中元素最少的个数,这里是10000个shuffle_buffer = 10000# 不需要从谷歌模型中加载的参数,这里就是最后的全连接层。因为输出类别不一样,所以最后全连接层的参数也不一样CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'# 需要训练的网络层参数 这里就是最后的全连接层TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
接下来我们定义几个辅助方法。首先因为我们的数据存在TFRecord里,需要定义方法从TFRecord解析数据。
def parse(record): features = tf.parse_single_example( record, features={ 'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64), 'pixels': tf.FixedLenFeature([], tf.int64) } ) #decode_raw用于解析TFRecord里面的字符串 decoded_image = tf.decode_raw(features['image_raw'], tf.uint8) label = features['label'] #要注意这里的decoded_image并不能直接进行reshape操作 #之前我们在存储的时候,把图片进行了tostring()操作 #这会导致图片的长度在原来基础上*8 #后面我们要用到numpy的fromstring来处理 return decoded_image, label
接下来定义两个方法。因为我们已经下载了谷歌训练好的inception_v3模型的参数,下面我们需要定义两个方法从里面加载参数。
#直接从inception_v3.ckpt中读取的参数def get_tuned_variables(): #strip删除头尾字符,默认为空格 exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")] variables_to_restore = [] #这里给出了所有slim模型下的参数 for var in slim.get_model_variables(): excluded = False for exclusion in exclusions: if var.op.name.startswith(exclusion): excluded = True break if not excluded: variables_to_restore.append(var) return variables_to_restore#需要重新训练的参数def get_trainable_variables(): #strip删除头尾字符,默认为空格 scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")] variables_to_train = [] # 枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。 for scope in scopes: #从TRAINABLE_VARIABLES集合中获取名为scope的变量 #也就是我们需要重新训练的参数 variables = tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope) variables_to_train.extend(variables) return variables_to_train
这里我们就写完了所需要的工具函数,接下来我们定义主函数。主函数主要完成数据读取,模型定义,通过模型得出前向传播结果,通过损失函数计算损失,最后把损失交给优化器做处理。首先我们先来完成数据读取的代码,这里我们使用的是TensorFlow高层API Dataset。不清楚的可以去看一下Dataset的用法。
这里我们在训练的同时也对模型做了验证。所以我们需要加载训练和验证数据
#读取测试数据 #利用TFRecordDataset读取TFRecord文件 dataset = tf.data.TFRecordDataset([INPUT_DATA]) #解析TFRecord dataset = dataset.map(parse) #把数据打乱顺序并组装成batch dataset = dataset.shuffle(shuffle_buffer).batch(BATCH) #定义数据重复的次数 NUM_EPOCHS = 10 dataset = dataset.repeat(NUM_EPOCHS) #定义迭代器来获取处理后的数据 iterator = dataset.make_one_shot_iterator() #迭代器开始迭代 img, label = iterator.get_next() #读取验证数据(同上) valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA]) valida_dataset = valida_dataset.map(parse) valida_dataset = valida_dataset.batch(BATCH) valida_iterator = valida_dataset.make_one_shot_iterator() valida_img,valida_label = valida_iterator.get_next() #定义inception-v3的输入,images为输入图片,label为每一张图片对应的标签 #再解释下每一个维度 None为batch的大小,299为图片大小,3为通道 images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images') labels = tf.placeholder(tf.int64,[None],name='labels')
要注意上述定义的只是tensorflow的张量,保存的只是计算过程并没有具体的数据。只有运行session之后才会拿到具体的数据。
下面我们来通过slim加载inception-v3模型
#定义inception-v3模型结构 inception_v3.ckpt里只有参数的取值 with slim.arg_scope(inception_v3.inception_v3_arg_scope()): #logits inception_v3前向传播得到的结果 logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES) #获取需要训练的变量 trainable_variables = get_trainable_variables() #这里用交叉熵作为损失函数,注意一下tf.losses.softmax_cross_entropy的参数 # tf.losses.softmax_cross_entropy( # onehot_labels, # 注意此处参数名就叫 onehot_labels # logits, # weights=1.0, # label_smoothing=0, # scope=None, # loss_collection=tf.GraphKeys.LOSSES, # reduction=Reduction.SUM_BY_NONZERO_WEIGHTS # ) #这里要把labels转成one_hot类型,logits就是神经网络的输出 tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0) #把计算的损失交给优化器处理 train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss()) #计算正确率。 with tf.name_scope('evaluation'): correct_prediction = tf.equal(tf.argmax(logits,1),labels) evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #定义加载模型的函数 load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True) #定义保存新的训练好的模型的函数 saver = tf.train.Saver() with tf.Session() as sess: #初始化所有变量 init = tf.global_variables_initializer() sess.run(init) print('Loading tuned variables from %s'%CKPT_FILE) #加载谷歌已经训练好的模型 load_fn(sess) step = 0; #在这里我们用一个while来循环训练,直到dataset里没有数据就结束循环 while True: try: if step % 30 == 0 or step + 1 == STEPS: #每30轮输出一次正确率 if step != 0: #每30轮保存一次当前模型的参数,以便中途训练中断可以继续 saver.save(sess,TRAIN_FILE,global_step=step) #运行session拿到真实图片的数据 valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label]) #上面有提到TFRecord里图片数据被转成了string,在这里转回来 valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32) #把图片张量拉成新的维度 valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3]) #用session运行上述操作,得到处理后的图片张量 valida_img_batch = sess.run(valida_img_batch) #把图片张量传到feed_dict算出正确率并显示 validation_accuracy = sess.run(evaluation_step,feed_dict={ images:valida_img_batch, labels:valida_label_batch }) print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0)) #下面是对训练数据的操作,同上 img_batch,label_batch = sess.run([img,label]) img_batch = np.fromstring(img_batch, dtype=np.float32) img_batch = tf.reshape(img_batch, [32,299, 299, 3]) img_batch = sess.run(img_batch) sess.run(train_step,feed_dict={ images:img_batch, labels:label_batch }) #step仅仅用于记录 step = step + 1 except tf.errors.OutOfRangeError: break
运行上述程序开始训练。在这里我暂时是使用cpu进行训练,训练过程大约3小时,可以得到类型下面的结果。
step 0:Validation accuracy = 12.5% step 30:Validation accuracy = 22.2% step 60:Validation accuracy = 63.2% step 90:Validation accuracy = 79.8% step 120:Validation accuracy = 86.4% step 150:Validation accuracy = 88.5% .....
作者:sidiWang
链接:https://www.jianshu.com/p/fc77879d3591