使用 tf.gather 基于另一个张量逐行提取张量

我有两个尺寸分别为 asA:[B,3000,3]和 的张量C:[B,4000]。我想使用tf.gather()张量 C 中的每一行作为索引,并使用张量 A 中的每一行作为参数,以获得 size 的结果[B,4000,3]

这是一个更容易理解的例子:假设我有张量

A = [[1,2,3],[4,5,6],[7,8,9]],

C = [0,2,1,2,1],

结果 = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],

通过使用tf.gather(A,C). 当应用于维度小于 3 的张量时,这一切都很好。

但是当以描述为开头的情况下,通过应用tf.gather(A,C,axis=1),结果张量的形状为

[B,B,4000,3]

似乎tf.gather()只是将张量 C 中的每个元素都作为索引来收集张量 A 中的元素。我正在考虑的唯一解决方案是使用for循环,但这会极大地降低计算能力,因为使用tf.gather(A[i,...],C[i,...])以获得正确的张量的大小

[B,4000,3]

因此,是否有任何功能能够类似地完成此任务?


一只名叫tom的猫
浏览 262回答 1
1回答

慕码人8056858

你需要使用tf.gather_nd:import tensorflow as tfA = ...  # B x 3000 x 3C = ...  # B x 4000s = tf.shape(C)B, cols = s[0], s[1]# Make indices for first dimensionidx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])# Complete index for gather_ndgather_idx = tf.stack([idx, C], axis=-1)# Gather resultresult = tf.gather_nd(A, gather_idx)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python