Pytorch张量获取具有特定值的元素的索引?

我有两个张量,张量 a 和张量 b。


我想获取张量 b 中值的所有索引。


例如。


a = torch.Tensor([1,2,2,3,4,4,4,5])

b = torch.Tensor([1,2,4])

1, 2, 4我想要张量 a的索引。我可以通过以下代码来做到这一点。


a = torch.Tensor([1,2,2,3,4,4,4,5])

b = torch.Tensor([1,2,4])

mask = torch.zeros(a.shape).type(torch.bool)

print(mask)

for e in b:

    mask = mask + (a == e)

    print(mask)

如果没有 ,我该怎么做for?


海绵宝宝撒
浏览 217回答 2
2回答

繁花不似锦

由于 PyTorch1.10和isin()(isinf()以及许多其他 numpy 等效项)也可用,因此您可以简单地执行以下操作:torch.isin(a, b)这会给你:Out[4]: tensor([ True,  True,  True, False,  True,  True,  True, False])旧答案:这是你想要的吗?:np.in1d(a.numpy(), b.numpy())将导致:array([ True,  True,  True, False,  True,  True,  True, False])

拉风的咖菲猫

如果您只是不想使用 for 循环,则可以使用列表理解:mask = [a[index] for index in b]如果甚至不想使用“for”一词,您可以随时将张量转换为 numpy 并使用 numpy 索引。mask = torch.tensor(a.numpy()[b.numpy()])更新可能误解了你的问题。在这种情况下,我想说实现这一点的最佳方法是通过列表理解。(切片可能无法实现这一点。mask = [index for index,value in enumerate(a) if value in b.tolist()]这会迭代 a 中的每个元素,获取它们的索引和值,如果该值在 b 内,则获取索引。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python