在 Tensorflow 数据集 API 中拆分数据集问题

我正在读取一个tf.contrib.data.make_csv_dataset用于形成数据集的 csv 文件,然后我使用该命令take()来形成另一个只有一个元素的数据集,但它仍然返回所有元素。


这里有什么问题?我带来了下面的代码:


import tensorflow as tf

import os

tf.enable_eager_execution()


# Constants


column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']

batch_size   = 1

feature_names = column_names[:-1]

label_name = column_names[-1]


# to reorient data strucute

def pack_features_vector(features, labels):

  """Pack the features into a single array."""

  features = tf.stack(list(features.values()), axis=1)

  return features, labels


# Download the file

train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"

train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),

                                       origin=train_dataset_url)


# form the dataset

train_dataset = tf.contrib.data.make_csv_dataset(

train_dataset_fp,

batch_size, 

column_names=column_names,

label_name=label_name,

num_epochs=1)


# perform the mapping

train_dataset = train_dataset.map(pack_features_vector)


# construct a databse with one element 

train_dataset= train_dataset.take(1)


# inspect elements

for step in range(10):

    features, labels = next(iter(train_dataset))

    print(list(features))


慕侠2389804
浏览 282回答 1
1回答

守着一只汪

基于这个答案,我们可以用Dataset.take()和分割数据集Dataset.skip():train_size = int(0.7 * DATASET_SIZE)train_dataset = full_dataset.take(train_size)test_dataset = full_dataset.skip(train_size)如何修复你的代码?不要在循环中多次创建迭代器,而是使用一个迭代器:# inspect elementsfor feature, label in train_dataset:    print(feature)在您的代码中发生了什么导致这种行为?1) 内置pythoniter函数从对象获取迭代器或对象本身必须提供自己的迭代器。所以当你调用的时候iter(train_dataset),就相当于调用了Dataset.make_one_shot_iterator()。2) 默认情况下,tf.contrib.data.make_csv_dataset()shuffle 中的参数为 True ( shuffle=True)。因此,每次调用iter(train_dataset)它时都会创建包含不同数据的新迭代器。3)最后,当循环通过for step in range(10)它时,类似于创建10个不同的迭代器,大小为1,每个迭代器都有自己的数据,因为它们被打乱了。建议:如果你想避免这样的事情在循环外初始化(创建)迭代器:train_dataset = train_dataset.take(1)iterator = train_dataset.make_one_shot_iterator()# inspect elementsfor step in range(10):    features, labels = next(iterator)    print(list(features))    # throws exception because size of iterator is 1
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python