喵喵时光机
Use `torch.einsum` followed by `torch.reshape`:

```python
AB = torch.einsum("ijk,ilk->ijlk", A, B).reshape(A.shape[0], -1, A.shape[2])
```

The einsum multiplies every row of `A[i]` element-wise with every row of `B[i]`, so that `AB[i, j, l, k] = A[i, j, k] * B[i, l, k]`. The reshape then merges the `j` and `l` axes into one axis of length `J*L`.

Example:

```python
import numpy as np
import torch

# A of shape (2, 3, 2):
A = torch.from_numpy(np.array([[[1, 1], [2, 2], [3, 3]],
                               [[4, 4], [5, 5], [6, 6]]]))
# B of shape (2, 2, 2):
B = torch.from_numpy(np.array([[[1, 1], [10, 10]],
                               [[2, 2], [20, 20]]]))

# AB of shape (2, 3*2, 2):
AB = torch.einsum("ijk,ilk->ijlk", A, B).reshape(A.shape[0], -1, A.shape[2])
print(AB)
# tensor([[[  1,   1], [ 10,  10], [  2,   2], [ 20,  20], [  3,   3], [ 30,  30]],
#         [[  8,   8], [ 80,  80], [ 10,  10], [100, 100], [ 12,  12], [120, 120]]])
```
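If einsum notation is unfamiliar, the same result can be sketched with plain broadcasting: insert a singleton axis into each tensor so that every row of `A[i]` lines up against every row of `B[i]`, multiply, then flatten the two row axes. The helper name `pairwise_rowwise_product` is hypothetical, chosen only for this illustration; the check at the end compares it against the einsum version on the same example data.

```python
import torch


def pairwise_rowwise_product(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: broadcasting equivalent of einsum("ijk,ilk->ijlk") + reshape.

    A: (N, J, K), B: (N, L, K)  ->  result: (N, J*L, K)
    """
    n, _, k = A.shape
    # A.unsqueeze(2): (N, J, 1, K); B.unsqueeze(1): (N, 1, L, K).
    # Broadcasting multiplies each row j of A with each row l of B -> (N, J, L, K).
    prod = A.unsqueeze(2) * B.unsqueeze(1)
    # Merge the J and L axes; j varies slowest, matching the einsum output order.
    return prod.reshape(n, -1, k)


A = torch.tensor([[[1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6]]])
B = torch.tensor([[[1, 1], [10, 10]], [[2, 2], [20, 20]]])

AB_einsum = torch.einsum("ijk,ilk->ijlk", A, B).reshape(A.shape[0], -1, A.shape[2])
AB_bcast = pairwise_rowwise_product(A, B)

print(torch.equal(AB_einsum, AB_bcast))  # True
print(AB_bcast.shape)  # torch.Size([2, 6, 2])
```

Both forms produce the same `(N, J*L, K)` tensor; the einsum string just spells out the index pattern that the `unsqueeze` calls encode implicitly.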