我有形状为 (batch_size, 200, 256) 的 LSTM 层的输出,其中 200 是标记序列的长度,256 是 LSTM 输出维度。我还有另一个形状为 (batch_size) 的张量,它是我想从批次中的每个样本序列中切出的标记的索引列表。
如果令牌索引不是 -1,我将切出一个令牌向量表示(长度 = 256)。如果令牌索引为 -1,我将给出零向量(长度 = 256)。
预期的输出结果具有形状 (batch_size, 1, 256)。我该怎么做?
谢谢
这是我到目前为止尝试过的
bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256)
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
matrix_to_slice = x[0]
index_tensor = x[1]
out_tensor = tf.where(index_tensor == -1,
tf.zeros(tf.shape(tf.gather(matrix_to_slice,
index_tensor, axis=1))),
tf.gather(matrix_to_slice, index_tensor, axis=1))
return out_tensor
representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0])
# stack_idx0 shape is (batch_size)
# I got output with shape (batch_size, batch_size, 256) with this code
慕娘9325324
相关分类