我使用 keras(张量流后端)训练了一个网络,并将模型保存为 json,权重保存为 h5。我现在尝试将其转换为单个张量流 pb 文件,它抱怨输出节点的名称。
系统信息:Tensorflow 2.3.0 Keras 2.4.3 Cuda 10.1 Cudnn 7
转换脚本非常简单:
import json
from tensorflow import keras
from keras import backend as K
import tensorflow as tf
json_file = "my-trained-model.json"
h5_file = "my-trained-model.h5"
Output_Path = "./trained_models/"
Frozen_pb_File = "my-trained-model.pb"
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.compat.v1.global_variables()]
# Graph -> GraphDef ProtoBuf
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
with open(json_file, 'r') as json_file:
model = keras.models.model_from_json(json_file.read())
model.load_weights(h5_file)
model.summary()
# get output node names
OutputNames = [out.op.name for out in model.outputs]
print("\nOutput Names:\n", OutputNames) # this prints "concatenate/concat" as the only output node name
# freeze the model
frozen_graph = freeze_session(tf.compat.v1.keras.backend.get_session(), output_names=OutputNames)
# save the output files
# this is the .pb file (a binary file)
tf.io.write_graph(frozen_graph, Output_Path, Frozen_pb_File, as_text=False)
当我运行这个时,
AssertionError: concatenate/concat is not in graph
因此,由于某种原因,它正在读取“concatenate/concat”的输出节点名称。下面给出模型总结,可以看到输出节点是“concatenate”;但是,即使我将输出节点名称硬编码为“连接”,我也会收到类似的断言错误:
AssertionError: concatenate is not in graph
繁华开满天机
相关分类