如何使用“预测”Sgature Def 在 Java 中加载 Tensorflow

我正在训练 Tensorflow Estimator 并用于export_saved_model以 SavedModel 格式保存模型。现在我想用 Tensorflow Java API 加载这个模型(我不想使用模型服务器,我需要直接用 Java 加载它)。现在的问题是,Estimator.export_saved_model仅导出“predict”signature_def,而SavedModelBundleJava中的似乎仅支持具有“serving_default”签名def的模型。所以问题是:有没有办法Estimator.export_saved_model包含“serving_default”签名 def?或者是否可以使用 java 中的“预测”签名 def 加载模型?或者还有其他我可以尝试的选择吗?


这是导出模型的代码:


feature_cols = [

        tf.feature_column.numeric_column('numeric_feature'),

        tf.feature_column.indicator_column( tf.feature_column.categorical_column_with_vocabulary_list('categorial_text_feature', vocabulary_list=['WORD1', 'WORD1']))

]


estimator = tf.estimator.LinearRegressor(

    feature_columns=feature_cols,

    model_dir=model_dir,

    label_dimension=1)


    estimator.train(input_fn=input_fn)


serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({

        'numeric_feature': tf.placeholder(tf.float32, shape=(None,)),

        'categorial_text_feature': tf.placeholder(tf.string, shape=(None,))

})

estimator.export_saved_model(

    export_dir_base=model_dir,

    serving_input_receiver_fn=serving_input_receiver_fn)

如果我检查模型,saved_model_cli show --tag_set serve我会得到:


The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:

SignatureDef key: "predict"

并与saved_model_cli show --tag_set serve --signature_def predict:


The given SavedModel SignatureDef contains the following input(s):

  inputs['numeric_feature'] tensor_info:

      dtype: DT_FLOAT

      shape: (-1)

      name: Placeholder:0

  inputs['categorial_text_feature'] tensor_info:

      dtype: DT_STRING

      shape: (-1)

      name: Placeholder_1:0

The given SavedModel SignatureDef contains the following output(s):

  outputs['predictions'] tensor_info:

      dtype: DT_FLOAT

      shape: (-1)

      name: linear/linear_model/linear_model/linear_model/weighted_sum:0

Method name is: tensorflow/serving/predict



茅侃侃
浏览 106回答 1
1回答

慕田峪7331174

找到了一个(不完美,但简单)的解决方法:我刚刚导出模型as_text=True:estimator.export_saved_model(         export_dir_base=model_dir,         serving_input_receiver_fn=serving_input_receiver_fn,         as_text=True)然后手动更改 .pbtxt 文件,使签名 def 称为“serving_default”
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Java