我有以下问题:
我有一个形状为 (1600, 29) 的张量,我想获得轴 1 的最大值。结果应该是 (1600, 1) 张量。为了简化,我将使用 (5,3) 张量来演示我的问题:
B = tf.constant([[2, 20, 30],
[2, 7, 6],
[3, 11, 16],
[19, 1, 8],
[14, 45, 23]])
x = x = tf.math.argmax(B, 1) --> 5 Values [2 1 2 0 1]
z = tf.gather(B, x, axis=1) --> Shape: (5,5) [[30 20 30 2 20]
[ 6 7 6 2 7]
[16 11 16 3 11]
[ 8 1 8 19 1]
[23 45 23 14 45]]
好的,现在,x 给了我轴上的最大元素,但是,tf.gather 不返回 [30, 7, 16, 19, 45],而是一些奇怪的张量。如何正确“减少”尺寸?
我的“相当”肮脏的方式是这样的:
eye_z = tf.eye(5, 5)
intermed_result = z*eye_z
result = tf.linalg.matvec(intermed_result,tf.transpose(tf.constant([1,1,1,1,1], dtype=tf.float32)))
这导致正确的张量:[30. 7. 16. 19. 45.]
繁花不似锦
相关分类