如何正确组合TensorFlow的数据集API和Keras?

Keras的fit_generator()模型方法期望生成器生成形状(输入,目标)的元组,其中两个元素都是NumPy数组。该文档似乎暗示着,如果我将Dataset迭代器简单地包装在生成器中,并确保将Tensors转换为NumPy数组,那我应该很好。这段代码给我一个错误:


import numpy as np

import os

import keras.backend as K

from keras.layers import Dense, Input

from keras.models import Model

import tensorflow as tf

from tensorflow.contrib.data import Dataset


os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


with tf.Session() as sess:

    def create_data_generator():

        dat1 = np.arange(4).reshape(-1, 1)

        ds1 = Dataset.from_tensor_slices(dat1).repeat()


        dat2 = np.arange(5, 9).reshape(-1, 1)

        ds2 = Dataset.from_tensor_slices(dat2).repeat()


        ds = Dataset.zip((ds1, ds2)).batch(4)

        iterator = ds.make_one_shot_iterator()

        while True:

            next_val = iterator.get_next()

            yield sess.run(next_val)


datagen = create_data_generator()


input_vals = Input(shape=(1,))

output = Dense(1, activation='relu')(input_vals)

model = Model(inputs=input_vals, outputs=output)

model.compile('rmsprop', 'mean_squared_error')

model.fit_generator(datagen, steps_per_epoch=1, epochs=5,

                    verbose=2, max_queue_size=2)

这是我得到的错误:


Using TensorFlow backend.

Epoch 1/5

Exception in thread Thread-1:

Traceback (most recent call last):

  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__

    fetch, allow_tensor=True, allow_operation=True))

  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element

    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)

奇怪的是,next(datagen)在我初始化的位置之后直接添加包含一行datagen的代码会使代码运行正常,没有错误。


为什么我的原始代码不起作用?将行添加到代码中后,为什么它开始起作用?是否有一种更有效的方式将TensorFlow的Dataset API与Keras结合使用,而无需将Tensors转换为NumPy数组然后再次返回?


交互式爱情
浏览 969回答 3
3回答
打开App,查看更多内容
随时随地看视频慕课网APP