将 keras/tensorflow h5/json 转换为 tensorflow pb 时遇到问题

我使用 keras(张量流后端)训练了一个网络,并将模型保存为 json,权重保存为 h5。我现在尝试将其转换为单个张量流 pb 文件,它抱怨输出节点的名称。


系统信息:Tensorflow 2.3.0 Keras 2.4.3 Cuda 10.1 Cudnn 7


转换脚本非常简单:


import json

from tensorflow import keras

from keras import backend as K

import tensorflow as tf


json_file = "my-trained-model.json"

h5_file = "my-trained-model.h5"

Output_Path = "./trained_models/"

Frozen_pb_File = "my-trained-model.pb"


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):

   

    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph

    with graph.as_default():

        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))

        output_names = output_names or []

        output_names += [v.op.name for v in tf.compat.v1.global_variables()]

        # Graph -> GraphDef ProtoBuf

        input_graph_def = graph.as_graph_def()

        if clear_devices:

            for node in input_graph_def.node:

                node.device = ""

        frozen_graph = convert_variables_to_constants(session, input_graph_def,

                                                      output_names, freeze_var_names)

        return frozen_graph



with open(json_file, 'r') as json_file:

    model = keras.models.model_from_json(json_file.read())

model.load_weights(h5_file)


model.summary()


# get output node names

OutputNames = [out.op.name for out in model.outputs]

print("\nOutput Names:\n", OutputNames)  # this prints "concatenate/concat" as the only output node name



# freeze the model

frozen_graph = freeze_session(tf.compat.v1.keras.backend.get_session(), output_names=OutputNames)


# save the output files

# this is the .pb file (a binary file)

tf.io.write_graph(frozen_graph, Output_Path, Frozen_pb_File, as_text=False)

当我运行这个时,


AssertionError: concatenate/concat is not in graph

因此,由于某种原因,它正在读取“concatenate/concat”的输出节点名称。下面给出模型总结,可以看到输出节点是“concatenate”;但是,即使我将输出节点名称硬编码为“连接”,我也会收到类似的断言错误:


AssertionError: concatenate is not in graph


红颜莎娜
浏览 70回答 1
1回答

繁华开满天机

看起来这都是由于尝试冻结 TensorFlow 2.3 模型而引起的。显然,Tensorflow 2.0+ 已弃用“冻结”概念,转而采用“保存模型”概念。一旦发现这一点,我就能够立即将 h5/json 保存到已保存的模型 pb 中。我仍然不确定这种格式是否针对推理进行了优化,所以我将对此进行一些跟进,但由于我的问题是关于我看到的错误,我想我会发布导致问题的原因。作为参考,这是我的 python 脚本,用于将 keras h5/json 文件转换为 Tensorflow 保存的模型格式。import osfrom keras.models import model_from_jsonimport tensorflow as tfimport genericpathfrom genericpath import *def splitext(p):    p = os.fspath(p)    if isinstance(p, bytes):        sep = b'/'        extsep = b'.'    else:        sep = '/'        extsep = '.'    return genericpath._splitext(p, sep, None, extsep)def load_model(path,custom_objects={},verbose=0):    from keras.models import model_from_json    path = splitext(path)[0]    with open('%s.json' % path,'r') as json_file:        model_json = json_file.read()    model = model_from_json(model_json, custom_objects=custom_objects)    model.load_weights('%s.h5' % path)    # if verbose: print 'Loaded from %s' % path    return modeljson_file = "model.json"  # the h5 file should be "model.h5"model = load_model(json_file) # load the json/h5 pairmodel.save('my_saved_model') # this is a directory name to store the saved model
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python