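TensorFlow uses the TFRecord format to store data in a uniform way. A TFRecord file is a sequence of records, usually serialized tf.train.Example protocol buffers, and each record can hold different types of information together, such as the image data, its label, the image path, and its width and height. This makes it easy to manage all of these attributes in one place.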
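To make "a sequence of records" concrete, here is a pure-Python sketch of the on-disk framing described in the TensorFlow documentation: each record is a little-endian uint64 length, a masked CRC-32C of those 8 length bytes, the payload, and a masked CRC-32C of the payload. It is for understanding only (in real code tf.python_io.TFRecordWriter handles this), and the function names crc32c, masked_crc, write_records and read_records are hypothetical helpers, not TensorFlow APIs.

```python
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0x82F63B78
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a 'masked' CRC: rotate right by 15 bits, then add a constant."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(records):
    """Hypothetical helper: frame each payload the way a TFRecord file does."""
    out = bytearray()
    for payload in records:
        header = struct.pack("<Q", len(payload))           # uint64 length
        out += header
        out += struct.pack("<I", masked_crc(header))       # masked CRC of the length bytes
        out += payload                                     # the serialized record itself
        out += struct.pack("<I", masked_crc(payload))      # masked CRC of the payload
    return bytes(out)


def read_records(blob):
    """Hypothetical helper: walk the framed bytes, verify both CRCs, yield payloads."""
    pos = 0
    while pos < len(blob):
        header = blob[pos:pos + 8]
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", blob[pos + 8:pos + 12])
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        payload = blob[pos + 12:pos + 12 + length]
        (data_crc,) = struct.unpack("<I", blob[pos + 12 + length:pos + 16 + length])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt record payload")
        yield payload
        pos += 16 + length
```

Every record costs 16 bytes of framing on top of its payload, and the two checksums let a reader detect a corrupted length or a corrupted record independently.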
Converting the training dataset to TFRecord
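The dataset used here comes from a project I am currently working on. It consists of two image folders (100 images each) and a corresponding label.txt. Each line of label.txt refers to one image in each of the two folders: it gives the two image paths followed by the position of the target object, namely the coordinates of its top-left and bottom-right corners (<x1><y1><x2><y2>). Next we turn this data into a TFRecord file. We will later need to verify that the generated TFRecord data is correct, and the images do not all have the same size, so besides the image content and the label the TFRecord file also stores each image's width, height, and number of channels. Without these values the parsed pixel data could not be reshaped back into an image.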
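The line format above can be checked in isolation before touching TensorFlow. The sketch below assumes comma-separated fields (target path, search path, x1, y1, x2, y2), which is how the conversion script splits each line; parse_label_line is a hypothetical helper and the example file names are made up.

```python
def parse_label_line(line):
    """Hypothetical helper: split one label.txt line into
    (target_path, search_path, [x1, y1, x2, y2]).

    Assumes comma-separated fields. Returns None when the line has fewer than
    two paths or does not carry exactly four numeric coordinates.
    """
    fields = [f.strip() for f in line.strip("\n").split(",")]
    if len(fields) < 2:
        return None
    target_path, search_path = fields[0], fields[1]
    try:
        box = [float(x) for x in fields[2:]]  # top-left and bottom-right corners
    except ValueError:
        return None
    if len(box) != 4:
        return None
    return target_path, search_path, box
```

For example, parse_label_line("target/0001.jpg,search/0001.jpg,10,20,110,220") returns the two paths and [10.0, 20.0, 110.0, 220.0], while a line with only three coordinates returns None and can be skipped, just as the conversion script skips it with an "invalid label" message.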
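Depending on how the image data is read, there are two ways to convert your own dataset to TFRecord format, and correspondingly two ways to parse it: method_1 reads the encoded JPEG bytes with tf.gfile, while method_2 opens the image with PIL and stores its raw, uncompressed pixel bytes. The code is as follows: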
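Why method_2 needs the stored height, width, and channel count: PIL's tobytes() produces a flat, row-major byte string with the channels of each pixel interleaved, and the same bytes can be split into many different shapes. The pure-Python sketch below mirrors what tf.decode_raw followed by a reshape does; unflatten_pixels is a hypothetical helper used only for illustration.

```python
def unflatten_pixels(raw, height, width, channels):
    """Hypothetical helper: rebuild a height x width x channels nested list
    from row-major raw pixel bytes (pixel channels interleaved)."""
    if len(raw) != height * width * channels:
        raise ValueError("byte count does not match height * width * channels")
    row_bytes = width * channels
    return [
        [list(raw[r * row_bytes + c * channels:r * row_bytes + (c + 1) * channels])
         for c in range(width)]          # one pixel = `channels` consecutive bytes
        for r in range(height)           # one row = width * channels bytes
    ]
```

Twelve bytes fit a 2x2 RGB image, a 1x4 RGB image, or a 3x4 grayscale image equally well, so the shape has to travel with the data. Method_1 does not have this problem, because a JPEG stream encodes its own dimensions and tf.image.decode_jpeg recovers them.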
```python
# Convert own_data to TFRecord of TF-Example protos (TensorFlow 1.x API).
import os

import tensorflow as tf
from PIL import Image


# Integer feature; a scalar is wrapped in a list because Int64List expects a sequence
def int64_feature(values):
    if not isinstance(values, (list, tuple)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


# Float feature
def float_feature(values):
    if not isinstance(values, (list, tuple)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


# Bytes feature; str values are encoded first because BytesList needs bytes in Python 3
def bytes_feature(value):
    if isinstance(value, str):
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# Path of the label file
dataset_dir = "/Users/**/**/label.txt"
# Root directory of the images
root_dir = "/Users/**/**/"
# Path of the output TFRecord file
output_filename = "/Users/**/**/output.tfrecord"

# Choose one reading method: 1 = tf.gfile + decode_jpeg, 2 = PIL Image.open
METHOD = 2

with open(dataset_dir) as f:
    file_lines = f.readlines()

# method_1: build the JPEG decoding op once instead of adding new ops to the graph on every iteration
jpeg_data = tf.placeholder(tf.string)
decoded_jpeg = tf.image.decode_jpeg(jpeg_data)
sess = tf.Session()

# Create a writer for the TFRecord file
writer = tf.python_io.TFRecordWriter(output_filename)
# Number of valid records written
valid_record_count = 0

# Read the entries to write from label.txt, one line at a time
for idx, line in enumerate(file_lines):
    line = line.strip('\n')
    fields = line.split(",")
    image_format = fields[0].split('.')[-1].lower()
    image_target_path = os.path.join(root_dir, fields[0])
    image_search_path = os.path.join(root_dir, fields[1])

    image_labels = [float(x) for x in fields[2:]]
    if len(image_labels) != 4:
        print("invalid label: " + line)
        continue

    if METHOD == 1:
        # method_1: read the raw file bytes with tf.gfile (binary mode), then decode to get height, width, channels
        image_target_data = tf.gfile.FastGFile(image_target_path, 'rb').read()
        image_search_data = tf.gfile.FastGFile(image_search_path, 'rb').read()
        T_height, T_width, channels = sess.run(decoded_jpeg, {jpeg_data: image_target_data}).shape
        S_height, S_width, _ = sess.run(decoded_jpeg, {jpeg_data: image_search_data}).shape
    else:
        # method_2: read with PIL and store the raw pixel bytes; note that Image.size is (width, height)
        image_target = Image.open(image_target_path)
        image_target_data = image_target.tobytes()
        T_width, T_height = image_target.size
        channels = len(image_target.getbands())
        image_search = Image.open(image_search_path)
        image_search_data = image_search.tobytes()
        S_width, S_height = image_search.size

    # Convert one sample to an Example protocol buffer holding all of its fields
    example = tf.train.Example(features=tf.train.Features(feature={
        'image_target/encoded': bytes_feature(image_target_data),
        'image_search/encoded': bytes_feature(image_search_data),
        'image_target/format': bytes_feature(image_format),
        'image_search/format': bytes_feature(image_format),
        'image/class/label': float_feature(image_labels),
        'image_target/height': int64_feature(T_height),
        'image_target/width': int64_feature(T_width),
        'image_search/height': int64_feature(S_height),
        'image_search/width': int64_feature(S_width),
        'image/channels': int64_feature(channels),
        'image_target/path': bytes_feature(image_target_path),
        'image_search/path': bytes_feature(image_search_path)
    }))
    # Write the Example to the TFRecord file
    writer.write(example.SerializeToString())
    valid_record_count += 1

writer.close()
sess.close()
print("\nvalid image count: " + str(valid_record_count))
```
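A quick way to confirm that the file really contains valid_record_count examples, without building a TensorFlow graph, is to walk the record length prefixes. This sketch assumes the documented TFRecord framing (8-byte little-endian length, 4-byte length CRC, payload, 4-byte payload CRC); count_tfrecords is a hypothetical helper, and it skips the checksums rather than verifying them.

```python
import struct


def count_tfrecords(f):
    """Hypothetical helper: count records in an open binary TFRecord file.

    Only the length prefixes are used; the CRC fields are skipped, not verified.
    """
    count = 0
    while True:
        header = f.read(8)
        if not header:
            break                                  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        body = f.read(4 + length + 4)              # length CRC + payload + payload CRC
        if len(body) != 4 + length + 4:
            raise ValueError("truncated record")
        count += 1
    return count
```

Usage: with open(output_filename, "rb") as f: print(count_tfrecords(f)). The result should equal the "valid image count" printed by the conversion script.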
To read the TFRecord file back and check its contents, use the following code:
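```python
# Read the TFRecord file back (TensorFlow 1.x queue-based API).
import os

import tensorflow as tf
from PIL import Image

# Path of the TFRecord file written above
tfrecord_path = "/Users/**/**/output.tfrecord"
# Output directory for the decoded images (hypothetical name; not given in the original)
decode_path = "decoded/"
os.makedirs(decode_path, exist_ok=True)

# Must match the method used when writing: 1 = tf.gfile (encoded JPEG), 2 = PIL (raw pixels)
METHOD = 2

# Read one serialized Example at a time from the file and parse its features
filename_queue = tf.train.string_input_producer([tfrecord_path])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'image_target/encoded': tf.FixedLenFeature([], tf.string),
    'image/class/label': tf.FixedLenFeature([4], tf.float32),
    'image_target/height': tf.FixedLenFeature([], tf.int64),
    'image_target/width': tf.FixedLenFeature([], tf.int64),
    'image/channels': tf.FixedLenFeature([], tf.int64),
    'image_target/path': tf.FixedLenFeature([], tf.string),
})

if METHOD == 1:
    # method_1: decode the JPEG bytes that were read with tf.gfile
    image_target = tf.image.decode_jpeg(features['image_target/encoded'])
else:
    # method_2: turn the raw bytes from PIL's tobytes() back into a uint8 pixel array
    image_target = tf.decode_raw(features['image_target/encoded'], tf.uint8)
label = features['image/class/label']
T_height = tf.cast(features['image_target/height'], tf.int32)
T_width = tf.cast(features['image_target/width'], tf.int32)
channels = tf.cast(features['image/channels'], tf.int32)
image_target_path = features['image_target/path']

sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Each sess.run call reads one example from the TFRecord file
for i in range(100):
    image_t, label_info, t_height, t_width, channel, path = sess.run(
        [image_target, label, T_height, T_width, channels, image_target_path])
    image_name = path.decode('utf-8').split('/')[-1].split('.')[0]
    # Reshape with NumPy instead of adding a new tf.reshape op to the graph on every iteration
    sample = image_t.reshape(t_height, t_width, channel)
    image = Image.fromarray(sample, 'RGB')  # assumes 3-channel images
    # Save the image as <image name>_<first label value>.jpg
    image.save(os.path.join(decode_path, image_name + '_' + str(label_info[0]) + '.jpg'))

coord.request_stop()
coord.join(threads)
sess.close()
```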
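One Python 3 detail in the loop above is that sess.run returns a tf.string feature as bytes, so the stored path must be decoded before string operations. The naming step can be pulled out and tested on its own; decoded_image_name is a hypothetical helper, and the example paths are made up.

```python
import os


def decoded_image_name(path, label, decode_dir):
    """Hypothetical helper: build '<image name>_<first label value>.jpg' in decode_dir.

    `path` may be the bytes value returned by sess.run for a tf.string feature;
    `label` is the float list stored under 'image/class/label'.
    """
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    image_name = os.path.splitext(os.path.basename(path))[0]  # drop directory and extension
    return os.path.join(decode_dir, "{}_{}.jpg".format(image_name, label[0]))
```

Unlike split('.')[0], os.path.splitext only strips the last extension, so a file named a.b.jpg keeps the name a.b instead of being truncated to a.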
Author: 我是笨徒弟
Link: https://www.jianshu.com/p/9448f71e9641