如何根据pytorch中另一个张量的值将张量的某个值更改为零?

我有两个张量:张量 a 和张量 b。如何根据张量 b 的值更改张量 a 的某些值?


我知道下面的代码是正确的,但是当张量很大时它运行起来很慢。还有其他方法吗?


import torch

a = torch.rand(10).cuda()

b = torch.rand(10).cuda()

a[b > 0.5] = 0.


四季花海
浏览 197回答 2
2回答

MMMHUHU

对于这个确切的用例,还要考虑a * (b <= 0.5)这似乎是以下最快的In [1]: import torch&nbsp; &nbsp;...: a = torch.rand(3**10)&nbsp; &nbsp;...: b = torch.rand(3**10)In [2]: %timeit a[b > 0.5] = 0.553 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)In [3]: a = torch.rand(3**10)In [4]: %timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)&nbsp; &nbsp;...:49 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)In [5]: a = torch.rand(3**10)In [6]: %timeit temp = (a * (b <= 0.5))44 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)In [7]: %timeit a.masked_fill_(b > 0.5, 0.)244 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

撒科打诨

我想torch.where会更快我在 CPU 中的测量是结果。import torcha = torch.rand(3**10)b = torch.rand(3**10)%timeit a[b > 0.5] = 0.852 µs ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)%timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)294 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python