For a TensorFlow dataset iterator (tf.data.Iterator), what is the best way to skip the first X batches, but only on the first pass through the data and not on subsequent passes when repeat() is specified?
I tried the following, but it did not work:
import tensorflow as tf
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.master('local[*]').config("spark.jars",'some/path/spark-tensorflow-connector_2.11-1.10.0.jar').getOrCreate()
df = pd.DataFrame({'x': range(10), 'y': [i*2 for i in range(10)]})
df = spark.createDataFrame(df)
df.write.format('tfrecords').option('recordType', 'Example').mode("overwrite").save('testdata')
def parse_function(proto):
    feature_description = {
        'x': tf.FixedLenFeature([], tf.int64),
        'y': tf.FixedLenFeature([], tf.int64)
    }
    parsed_features = tf.parse_single_example(proto, feature_description)
    x = parsed_features['x']
    y = parsed_features['y']
    return {'x': x, 'y': y}

def load_data(filename_pattern, parse_function, batch_size=200, skip_batches=0):
    files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.repeat()
    dataset = dataset.map(parse_function)
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(2)
    # Create an iterator
    iterator = dataset.make_one_shot_iterator()
    data = iterator.get_next()
    with tf.Session() as sess:
        for i in range(skip_batches):
            sess.run(data)
    return data

# skip first three batches
data = load_data('testdata/part-*', parse_function, batch_size=2, skip_batches=3)

sess = tf.Session()
for i in range(3):
    print(sess.run(data))
Expected/desired:
{'y': array([12, 14]), 'x': array([6, 7])}
{'y': array([16, 18]), 'x': array([8, 9])}
{'y': array([0, 2]), 'x': array([0, 1])}
Actual:
{'y': array([0, 2]), 'x': array([0, 1])}
{'y': array([4, 6]), 'x': array([2, 3])}
{'y': array([8, 10]), 'x': array([4, 5])}
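My guess is that the skip does not carry over because each tf.Session() holds its own copy of the one-shot iterator's state, so the batches consumed inside load_data never advance the iterator that the second session reads from. Below is a minimal, unverified sketch of one possible direction: expressing the skip in the pipeline itself via skip() and concatenate() instead of draining batches in a throwaway session. The helper name load_data_skip_first is mine, and the sketch assumes skipping at the batch level is acceptable:

import tensorflow as tf

def load_data_skip_first(filename_pattern, parse_function, batch_size=200, skip_batches=0):
    files = tf.data.Dataset.list_files(file_pattern=filename_pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(parse_function)
    # Batch the finite (un-repeated) dataset first, so one pass over
    # the data is a well-defined, finite sequence of batches.
    batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    # First pass: everything after the first skip_batches batches;
    # after that: endless repetitions of the full batched dataset.
    dataset = batched.skip(skip_batches).concatenate(batched.repeat())
    dataset = dataset.prefetch(2)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

data = load_data_skip_first('testdata/part-*', parse_function, batch_size=2, skip_batches=3)
# Reuse one session for every sess.run call, so iterator state persists.
with tf.Session() as sess:
    for i in range(3):
        print(sess.run(data))

With batch_size=2 and skip_batches=3 this should print the three batches listed under Expected above: the first pass starts at the fourth batch, and the concatenated repeat() restarts from the full dataset.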