我正在尝试找出将datasetapi 与 api 一起使用的推荐方法estimator。我在网上看到的一切都是这个的一些变体:
def train_input_fn():
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
return dataset
然后可以将其传递给估算器的 train 函数:
classifier.train(
input_fn=train_input_fn,
#...
)
但数据集指南警告说:
上面的代码片段会将特征和标签数组作为 tf.constant() 操作嵌入到您的 TensorFlow 图中。这适用于小数据集,但会浪费内存——因为数组的内容将被多次复制——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。
然后描述一种方法,该方法涉及定义占位符,然后用 填充feed_dict:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
但是,如果您使用的是estimatorapi,则不会手动运行会话。那么如何将datasetapi 与 estimators 一起使用,同时避免与 相关的问题from_tensor_slices()?
万千封印
相关分类