我有一个使用 Tensorflow 2 和 Keras 定义的大模型。该模型在 Python 中运行良好。现在,我想将它导入到 C++ 项目中。
在我的 C++ 项目中,我使用TF_GraphImportGraphDef
函数。*.pb
如果我使用以下代码准备文件,效果很好:
with open('load_model.pb', 'wb') as f: f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())
我已经在使用 Tensorflow 1(使用 tf.compat.v1.* 函数)编写的简单网络上尝试了这段代码。它运作良好。
现在我想将我的大模型(开头提到的,使用Tensorflow 2编写)导出到C++项目中。为此,我需要从我的模型中获取一个Graph
或GraphDef
对象。问题是:如何做到这一点?我没有找到任何属性或函数来获取它。
我也试过用它tf.saved_model.save(model, 'model')
来保存整个模型。它生成一个包含不同文件的目录,包括saved_model.pb
文件。不幸的是,当我尝试使用TF_GraphImportGraphDef
函数在 C++ 中加载此文件时,程序抛出异常。
海绵宝宝撒
宝慕林4294392
相关分类