TensorFlow:如何将图像解码器节点添加到我的图形中?

我有一个张量流模型作为冻结图,它接受一个图像张量作为输入。但是,我想向该图中添加一个新的输入图像解码器节点,以便模型也接受 jpg 图像的编码字节字符串,并最终自行解码图像。到目前为止,我已经尝试过这种方法:


model = './frozen_graph.pb'


with tf.gfile.FastGFile(model, 'rb') as f:


    # read graph

    graph_def = tf.GraphDef()

    graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name="")

    g = tf.get_default_graph()


    # fetch old input

    old_input = g.get_tensor_by_name('image_tensor:0')


    # define new input

    new_input = graph_def.node.add()

    new_input.name = 'encoded_image_string_tensor'

    new_input.op = 'Substr'

    # add new input attr

    image = tf.image.decode_image(new_input, channels=3)


    # link new input to old input

    old_input.input = 'encoded_image_string_tensor'  #  must match with the name above

上面的代码返回这个异常:


Expected string passed to parameter 'input' of op 'Substr', got name: "encoded_image_string_tensor" op: "Substr"  of type 'NodeDef' instead.

我不太确定我是否可以tf.image.decode_image在图表中使用 ,所以也许有另一种方法可以解决这个问题。有人有提示吗?


守着一只汪
浏览 158回答 1
1回答

守着星空守着你

使用该input_map参数,我成功地将一个仅解码 jpg 图像的新图形映射到我的原始图形的输入(此处:)node.name='image_tensor:0'。只需确保重命名name_scope解码器图的 的(此处:)decoder。之后,您可以使用 tensorflow SavedModelBuilder 保存新的连接图。这是一个物体检测网络的例子:import tensorflow as tffrom tensorflow.python.saved_model import signature_constantsfrom tensorflow.python.saved_model import tag_constants# The export path contains the name and the version of the modelmodel = 'path/to/model.pb'export_path = './output/dir/'sigs = {}with tf.gfile.FastGFile(model, 'rb') as f:        with tf.name_scope('decoder'):                image_str_tensor = tf.placeholder(tf.string, shape=[None], name= 'encoded_image_string_tensor')                # The CloudML Prediction API always "feeds" the Tensorflow graph with                # dynamic batch sizes e.g. (?,).  decode_jpeg only processes scalar                # strings because it cannot guarantee a batch of images would have                # the same output size.  We use tf.map_fn to give decode_jpeg a scalar                # string from dynamic batches.                def decode_and_resize(image_str_tensor):                        """Decodes jpeg string, resizes it and returns a uint8 tensor."""                        image = tf.image.decode_jpeg(image_str_tensor, channels=3)                        # do additional image manipulation here (like resize etc...)                        image = tf.cast(image, dtype=tf.uint8)                        return image                image = tf.map_fn(decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)        with tf.name_scope('net'):                # load .pb file                graph_def = tf.GraphDef()                graph_def.ParseFromString(f.read())                # concatenate decoder graph and original graph                tf.import_graph_def(graph_def, name="", input_map={'image_tensor:0':image})                g = tf.get_default_graph()with tf.Session() as sess:        # load graph into session and save to new .pb file        # define model input        inp = g.get_tensor_by_name('decoder/encoded_image_string_tensor:0')        # define model outputs        num_detections = g.get_tensor_by_name('num_detections:0')        detection_scores = g.get_tensor_by_name('detection_scores:0')        detection_boxes = g.get_tensor_by_name('detection_boxes:0')        out = {'num_detections': num_detections, 'detection_scores': detection_scores, 'detection_boxes': detection_boxes}        builder = tf.saved_model.builder.SavedModelBuilder(export_path)        tensor_info_inputs = {                'inputs': tf.saved_model.utils.build_tensor_info(inp)}        tensor_info_outputs = {}        for k, v in out.items():                tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)        # assign detection signature for tensorflow serving        detection_signature = (        tf.saved_model.signature_def_utils.build_signature_def(                inputs=tensor_info_inputs,                outputs=tensor_info_outputs,                method_name=signature_constants.PREDICT_METHOD_NAME))        # "build" graph        builder.add_meta_graph_and_variables(                sess, [tf.saved_model.tag_constants.SERVING],                signature_def_map={                'detection_signature':                        detection_signature,                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:                        detection_signature,                },                main_op=tf.tables_initializer()        )        # save graph        builder.save()另外:如果您难以找到正确的输入和输出节点,您可以运行它来显示图形:graph_op = g.get_operations()for i in graph_op:    print(i.node_def)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python