tensorflow in_top_k 的输入应该是 1 级还是 2 级?

我尝试尝试使用 in_top_k 函数来查看该函数到底在做什么。但我发现了一些非常令人困惑的行为。


首先我编码如下


import numpy as np

import tensorflow as tf


target = tf.constant(np.random.randint(2, size=30).reshape(30,-1), dtype=tf.int32, name="target")

pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")

result = tf.nn.in_top_k(pred, target, 1)


init = tf.global_variables_initializer()


with tf.Session() as sess:

    sess.run(init)

    targetVal = target.eval()

    predVal = pred.eval()

    resultVal = result.eval()

然后它生成以下错误:


ValueError: Shape must be rank 1 but is rank 2 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30,1], [30,1], [].

然后我将代码更改为


import numpy as np

import tensorflow as tf

target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")

pred = tf.constant(np.random.rand(30,1).reshape(-1), dtype=tf.float32, name="pred")

result = tf.nn.in_top_k(pred, target, 1)


init = tf.global_variables_initializer()


with tf.Session() as sess:

    sess.run(init)

    targetVal = target.eval()

    predVal = pred.eval()

    resultVal = result.eval()

但现在错误变成了


ValueError: Shape must be rank 2 but is rank 1 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30], [30], [].

那么输入应该是 1 级还是 2 级?


叮当猫咪
浏览 136回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python