展平多个文件张量流的数据集

我正在尝试从 6 个 .bin 文件中读取 CIFAR-10 数据集,然后创建一个可初始化的迭代器。这是我下载数据的站点,它还包含二进制文件结构的描述。每个文件包含 2500 张图像。然而,生成的迭代器只为每个文件生成一个张量,一个大小为 (2500,3703) 的张量。这是我的代码


import tensorflow as tf


filename_dataset = tf.data.Dataset.list_files("cifar-10-batches-bin/*.bin")    

image_dataset = filename_dataset.map(lambda x: tf.decode_raw(tf.read_file(x), tf.float32))


iter_ = image_dataset.make_initializable_iterator()

next_file_data = iter_.get_next()I 


next_file_data = tf.reshape(next_file_data, [-1,3073])

next_file_img_data, next_file_labels = next_file_data[:,:-1], next_file_data[:,-1]

next_file_img_data = tf.reshape(next_file_img_data, [-1,32,32,3])


init_op = iter_.initializer


with tf.Session() as sess:

    sess.run(init_op)

    print(next_file_img_data.eval().shape) 



_______________________________________________________________________


>> (2500,32,32,3)

前两行基于此答案。我希望能够指定由 生成的图像数量get_next(),batch()而不是使用每个 .bin 文件中的图像数量,这里是 2500。


这里已经有一个关于展平数据集的问题,但我不清楚答案。特别是,该问题似乎包含来自在别处定义的类函数的代码片段,我不确定如何实现它。


我也试过用 来创建数据集tf.data.Dataset.from_tensor_slices(),用 替换上面的第一行


import os


filenames = [os.path.join('cifar-10-batches-bin',f) for f in os.listdir("cifar-10-batches-bin") if f.endswith('.bin')]

filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)

但这并没有解决问题。


任何帮助将不胜感激。谢谢。


料青山看我应如是
浏览 208回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python