猿问

在java中提供TensorFlow - 在一个会话运行中进行多个预测

我有一个保存的模型,我设法加载,运行并获得1行9个特征的预测。(输入)现在我试图预测100行这样的行,但是当尝试从Tensor.copyTo()读取结果数组时,我得到了不兼容的形状


java.lang.IllegalArgumentException: cannot copy Tensor with shape [1, 1] into object with shape [100, 1]

显然,我设法在循环中运行了这个预测 - 但这比一次运行100的等效python执行慢20倍。


这里是 /saved_model_cli.py 报告的已保存模型信息


MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:


signature_def['serving_default']:

  The given SavedModel SignatureDef contains the following input(s):

    inputs['input'] tensor_info:

        dtype: DT_FLOAT

        shape: (-1, 9)

        name: dense_1_input:0

  The given SavedModel SignatureDef contains the following output(s):

    outputs['output'] tensor_info:

        dtype: DT_FLOAT

        shape: (-1, 1)

        name: dense_4/BiasAdd:0

  Method name is: tensorflow/serving/predict

问题是 - 我是否需要为我想预测的每一行运行(),就像这里的问题一样


aluckdog
浏览 126回答 1
1回答

catspeake

好吧,所以我发现了一个问题,我无法为我想要的所有行(预测)运行一次。可能是一个张量流新手问题,我搞砸了输入和输出矩阵。当报告工具(python)说你有一个形状(-1,9)的输入张量映射到java long[]{1,9}时,这并不意味着你不能传递long[]{1000,9}的输入张量 - 这意味着1000行用于预测。在此输入之后,定义为 [1,1] 的输出张量可以是 [1000,1]。这个代码实际上比python运行得快得多(1.2秒对7秒),这是代码(也许会解释得更好)public Tensor prepareData(){&nbsp; &nbsp; Random r = new Random();&nbsp; &nbsp; float[]inputArr = new float[NUMBER_OF_KEWORDS*NUMBER_OF_FIELDS];&nbsp; &nbsp; for (int i=0;i<NUMBER_OF_KEWORDS * NUMBER_OF_FIELDS;i++){&nbsp; &nbsp; &nbsp; &nbsp; inputArr[i] = r.nextFloat();&nbsp; &nbsp; }&nbsp; &nbsp; FloatBuffer inputBuff = FloatBuffer.wrap(inputArr, 0, NUMBER_OF_KEWORDS*NUMBER_OF_FIELDS);&nbsp; &nbsp; return Tensor.create(new long[]{NUMBER_OF_KEWORDS,NUMBER_OF_FIELDS}, inputBuff);}public void predict (Tensor inputTensor){&nbsp; &nbsp; try ( Session s = savedModelBundle.session()) {&nbsp; &nbsp; &nbsp; &nbsp; Tensor result;&nbsp; &nbsp; &nbsp; &nbsp; long globalStart = System.nanoTime();&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; result = s.runner().feed("dense_1_input", inputTensor).fetch("dense_4/BiasAdd").run().get(0);&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; final long[] rshape = result.shape();&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; if (result.numDimensions() != 2 || rshape[0] <= NUMBER_OF_KEWORDS) {&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; throw new RuntimeException(&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; String.format(&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; "Expected model to produce a [N,1] shaped tensor where N is the number of labels, instead it produced one with shape %s",&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; Arrays.toString(rshape)));&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; }&nbsp; &nbsp; &nbsp; &nbsp; float[][] resultArray = (float[][]) result.copyTo(new float[NUMBER_OF_KEWORDS][1]);&nbsp; &nbsp; &nbsp; &nbsp; System.out.println(String.format("Total of %d,&nbsp; took : %.4f ms", NUMBER_OF_KEWORDS, ((double) System.nanoTime() - globalStart) / 1000000));&nbsp; &nbsp; &nbsp; &nbsp; for (int i=0;i<10;i++){&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; System.out.println(resultArray[i][0]);&nbsp; &nbsp; &nbsp; &nbsp; }&nbsp; &nbsp; }}
随时随地看视频慕课网APP

相关分类

Java
我要回答