将 keras h5 转换为 tensorflow pb 以进行批量推理

我使用从keras h5模型导出的tensorflow protobuf图进行批处理推理时遇到问题。尽管导出的pb图可以接受多个输入(样本),但无论输入数量如何,它始终提供单个输出。下面是一个简单的例子来演示这个问题。


from keras.models import Model,load_model

from keras.layers import Dense, Input

from keras import backend as K

import tensorflow as tf

import numpy as np

import os

import os.path as osp


pinput = Input(shape=[10,], name='my_input')

poutput = Dense(1, activation='sigmoid')(pinput)

model = Model(inputs=[pinput], outputs=[poutput])


model.compile(loss='mean_squared_error',optimizer='sgd',metrics=['accuracy'])

data = np.random.random((100, 10))

labels = np.random.randint(2, size=(100, 1))

model.fit(data, labels, epochs=1, batch_size=32)


x = np.random.random((3, 10))

y = model.predict(x)

print y


####################################

# Save keras h5 to tensorflow pb

####################################


K.set_learning_phase(0)


#alias output names

numoutputs = 1

pred = [None]*numoutputs

pred_node_names = [None]*numoutputs

for i in range(numoutputs):

    pred_node_names[i] = 'output'+'_'+str(i)

    pred[i] = tf.identity(model.output[i], name=pred_node_names[i])

print('Output nodes names are: ', pred_node_names)


sess = K.get_session()


# Write the graph in human readable

f = 'graph_def_for_reference.pb.ascii'

tf.train.write_graph(sess.graph.as_graph_def(), '.', f, as_text=True)


input_graph_def = sess.graph.as_graph_def()


#freeze graph

from tensorflow.python.framework.graph_util import convert_variables_to_constants

output_names = pred_node_names

output_names += [v.op.name for v in tf.global_variables()]

constant_graph = convert_variables_to_constants(sess, input_graph_def,output_names)


# Write the graph in binary .pb file

from tensorflow.python.framework import graph_io

graph_io.write_graph(constant_graph, '.', 'model.pb', as_text=False)



您可以看到 keras h5 图给出了 3 个输出,而 tensorflow pb 图只给出了第一个输出。我究竟做错了什么?我想修改 h5 到 pb 的转换过程,以便我可以使用 pb 图形和 python 和 c++ tensorflow 后端进行批量推理。


凤凰求蛊
浏览 241回答 1
1回答

MYYA

事实证明,这是由于我从k2tf_convert继承的错误所致pred[i] = tf.identity(model.output[i], name=pred_node_names[i])应该pred[i] = tf.identity(model.outputs[i], name=pred_node_names[i])看来keras模型类同时具有'output'和'outputs'成员,这使得此bug难以跟踪。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python