我有这个张量:
tf_a1 = [[0 3 1 22]
[3 5 2 2]
[2 6 3 13]
[1 7 0 3 ]
[4 9 11 10]]
threshold我想要做的是找到在所有列中重复超过 a 的唯一值。
例如这里,3在 中重复4 columns。0中重复2 columns。2重复在3 columns等等。
我希望我的输出是这样的(假设阈值为2,因此重复超过 2 次的索引将被屏蔽)。
[[F T F F]
[T F T T]
[T F T F]
[F F F T]
[F F F F]]
这就是我所做的:
y, idx, count = tf.unique_with_counts(tf_a1)
tf.where(tf.where(count, tf_a1, tf.zeros_like(tf_a1)))
但它引发了错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: unique 需要一维向量。[操作:UniqueWithCounts]
谢谢。
森栏
鸿蒙传说
相关分类