torch / np einsum 内部到底是如何工作的

torch.einsum这是有关GPU内部工作的查询。我知道如何使用einsum。它是执行所有可能的矩阵乘法,然后只选择相关的矩阵乘法,还是仅执行所需的计算?


例如,考虑形状 的两个张量a和,我希望找到形状的每个相应张量 的点积。使用einsum,代码为:b(N,P)ni(1,P)


torch.einsum('ij,ij->i',a,b)

在不使用 einsum 的情况下,获取输出的另一种方法是:


torch.diag(a @ b.t())

现在,第二个代码应该比第一个代码执行更多的计算(例如, if N= 2000,它执行2000更多的计算)。然而,当我尝试对这两个操作进行计时时,它们完成所需的时间大致相同,这就引出了一个问题。是否einsum执行所有组合(如第二个代码),并挑选出相关值?


要测试的示例代码:


import time

import torch

for i in range(100):

  a = torch.rand(50000, 256).cuda()

  b = torch.rand(50000, 256).cuda()


  t1 = time.time()

  val = torch.diag(a @ b.t())

  t2 = time.time()

  val2 = torch.einsum('ij,ij->i',a,b)

  t3 = time.time()

  print(t2-t1,t3-t2, torch.allclose(val,val2))


慕码人2483693
浏览 144回答 2
2回答

SMILET

这可能与 GPU 可以并行计算a @ b.t(). 这意味着 GPU 实际上不必等待每个行列乘法计算完成即可计算下一个乘法。如果您检查 CPU,您会发现它torch.diag(a @ b.t())比torch.einsum('ij,ij->i',a,b) 大型a和b.

莫回无

我不能代表,但几年前曾在一些细节上torch合作过。np.einsum然后它根据索引字符串构造一个自定义迭代器,仅执行必要的计算。从那时起,它以各种方式进行了重新设计,显然将问题转化为@可能的情况,从而利用了 BLAS(等)库调用。In [147]: a = np.arange(12).reshape(3,4)In [148]: b = aIn [149]: np.einsum('ij,ij->i', a,b)Out[149]: array([ 14, 126, 366])我不能确定在这种情况下使用了什么方法。通过“j”求和,还可以通过以下方式完成:In [150]: (a*b).sum(axis=1)Out[150]: array([ 14, 126, 366])正如您所注意到的,最简单的方法dot创建一个更大的数组,我们可以从中拉出对角线:In [151]: (a@b.T).shapeOut[151]: (3, 3)但这不是正确的使用方法@。 通过提供高效的“批量”处理@进行扩展。np.dot所以i维度是批次一,也是j一dot。In [152]: a[:,None,:]@b[:,:,None]Out[152]: array([[[ 14]],       [[126]],       [[366]]])In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]Out[156]: array([ 14, 126, 366])换句话说,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 维度上进行乘积之和。一些采样时间:In [162]: timeit np.einsum('ij,ij->i', a,b)7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)In [163]: timeit (a*b).sum(axis=1)9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)In [164]: timeit np.diag(a@b.T)10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python