如何索引具有形状 (batch_size, 200, 256) 的张量以获得

我有形状为 (batch_size, 200, 256) 的 LSTM 层的输出,其中 200 是标记序列的长度,256 是 LSTM 输出维度。我还有另一个形状为 (batch_size) 的张量,它是我想从批次中的每个样本序列中切出的标记的索引列表。


如果令牌索引不是 -1,我将切出一个令牌向量表示(长度 = 256)。如果令牌索引为 -1,我将给出零向量(长度 = 256)。


预期的输出结果具有形状 (batch_size, 1, 256)。我该怎么做?


谢谢


这是我到目前为止尝试过的


bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 

dropout = Dropout(params['dropout_rate'])(bidir)

def slice_by_tensor(x):

    matrix_to_slice = x[0]

    index_tensor = x[1]



    out_tensor = tf.where(index_tensor == -1, 

                          tf.zeros(tf.shape(tf.gather(matrix_to_slice, 

                                                      index_tensor, axis=1))), 

                          tf.gather(matrix_to_slice, index_tensor, axis=1))




    return out_tensor



representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) 

# stack_idx0 shape is (batch_size) 

# I got output with shape (batch_size, batch_size, 256) with this code


HUX布斯
浏览 128回答 1
1回答

慕娘9325324

a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))#     [[[ 0,  1,  2,  3],#        [ 4,  5,  6,  7],#        [ 8,  9, 10, 11]],#      [[12, 13, 14, 15],#      [16, 17, 18, 19],#       [20, 21, 22, 23]]]b=tf.constant([-1,2]) aa=tf.pad(a,[[0,0],[1,0],[0,0]]) bb=b+1 index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)#[[[ 0,  0,  0,  0]],#[[20, 21, 22, 23]]]当 index 为 -1 时,我们需要像张量这样的零。所以我们可以先沿第二个轴填充原始张量。然后将索引增加 1。在此之后,使用tf.gather_nd将返回答案。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python