猿问

两个张量的Pytorch广播产品

我想将两个张量相乘,这是我得到的:

  • A 形状张量 (20, 96, 110)

  • B 形状张量 (20, 16, 110)

第一个索引用于批次大小。我想要做的基本上是从B-中获取每个张量(20, 1, 110),例如,然后,我希望将每个A张量相乘(20, n, 110)。因此乘积将在最后:AB形状为的张量(20, 96 * 16, 110)

所以我想A通过广播将每个张量相乘B。在PyTorch中有做到这一点的方法吗?


拉丁的传说
浏览 176回答 1
1回答

喵喵时光机

使用torch.einsum后跟torch.reshape:AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])例子:import numpy as npimport torch# A of shape (2, 3, 2):A = torch.from_numpy(np.array([[[1, 1], [2, 2], [3, 3]],                                [[4, 4], [5, 5], [6, 6]]]))# B of shape (2, 2, 2):B = torch.from_numpy(np.array([[[1, 1], [10, 10]],                                [[2, 2], [20, 20]]]))# AB of shape (2, 3*2, 2):AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])# tensor([[[ 1, 1], [ 10, 10], [  2,  2], [ 20,   20], [ 3,   3], [ 30,  30]],#         [[ 8, 8], [ 80, 80], [ 10, 10], [ 100, 100], [ 12, 12], [ 120, 120]]])
随时随地看视频慕课网APP

相关分类

Python
我要回答