如何指定 tf.Data.iterator 的起点(或跳过前 X 个批次)?

对于张量流数据集迭代器(tf.data.Iterator),跳过前 X 个批次的最佳方法是什么,但仅在第一次迭代中,而不是在指定 repeat() 时的后续迭代)?


我尝试了以下但没有奏效:


import tensorflow as tf

import pandas as pd

from pyspark.sql import SparkSession


spark = SparkSession.builder.master('local[*]').config("spark.jars",'some/path/spark-tensorflow-connector_2.11-1.10.0.jar').getOrCreate()


df = pd.DataFrame({'x': range(10), 'y': [i*2 for i in range(10)]})

df = spark.createDataFrame(df)


df.write.format('tfrecords').option('recordType', 'Example').mode("overwrite").save('testdata')


def parse_function(proto):

    feature_description = {

    'x': tf.FixedLenFeature([], tf.int64),

    'y': tf.FixedLenFeature([], tf.int64)

    }


    parsed_features = tf.parse_single_example(proto, feature_description)


    x = parsed_features['x']

    y = parsed_features['y']


    return {'x': x, 'y': y}


def load_data(filename_pattern, parse_function, batch_size=200, skip_batches=0):

    files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)

    dataset = tf.data.TFRecordDataset(files)


    dataset = dataset.repeat()


    dataset = dataset.map(parse_function)

    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))


    dataset = dataset.prefetch(2)


    # Create an iterator

    iterator = dataset.make_one_shot_iterator()

    data = iterator.get_next()


    with tf.Session() as sess:

        for i in range(skip_batches):

            sess.run(data)


    return data


# skip first three batches

data = load_data('testdata/part-*', parse_function, batch_size=2, skip_batches=3)


sess = tf.Session()


for i in range(3):

    print(sess.run(data))

预期/期望:


    {'y': array([12, 14]), 'x': array([6, 7])}

    {'y': array([16, 18]), 'x': array([8, 9])}

    {'y': array([0, 2]), 'x': array([0, 1])}

实际的:


    {'y': array([0, 2]), 'x': array([0, 1])}

    {'y': array([4, 6]), 'x': array([2, 3])}

    {'y': array([8, 10]), 'x': array([4, 5])}


一只名叫tom的猫
浏览 122回答 1
1回答

摇曳的蔷薇

tf.Dataset.iterator()你为什么不跳过前 X 批,而不是通过?假设您想要 10 个批次,每个批次有 32 个元素,这意味着总共 320 个元素。因此,您可以使用tf.Dataset.skip(320)( skip ) 跳过这些,它会为您提供跳过前 10 个批次的数据集。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python