在 PyTorch 中计算欧几里德距离而不是矩阵乘法

假设我们有 2 个矩阵:


mat = torch.randn([20, 7]) * 100

mat2 = torch.randn([7, 20]) * 100


n, m = mat.shape


最简单的常用矩阵乘法如下所示:


def mat_vec_dot_product(mat, vect):

    n, m = mat.shape

    

    res = torch.zeros([n])

    for i in range(n):

        for j in range(m):

            res[i] += mat[i][j] * vect[j]

        

    return res


res = torch.zeros([n, n])

for k in range(n):

    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])

    

但是如果我需要应用 L2 范数而不是点积怎么办?代码如下:


def mat_vec_l2_mult(mat, vect):

    n, m = mat.shape

    

    res = torch.zeros([n])

    for i in range(n):

        for j in range(m):

            res[i] += (mat[i][j] - vect[j]) ** 2

            

    res = res.sqrt()

        

    return res


for k in range(n):

    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])


我们可以使用 Torch 或任何其他库以最佳方式做到这一点吗?因为简单的 O(n^3) Python 代码运行速度非常慢。


MMMHUHU
浏览 212回答 2
2回答

慕虎7371278

用于torch.cdistL2 范数 - 欧氏距离res = torch.cdist(mat, mat2.permute(1,0), p=2)在这里,我曾经将frompermute的 dim 交换为mat27,2020,7

翻翻过去那场雪

首先,PyTorch 中的矩阵乘法有一个内置运算符:@。因此,要将 mat 和 mat2 相乘,您只需执行以下操作:mat @ mat2(假设尺寸一致,应该可以工作)。现在,要计算您似乎在第二个块中计算的平方差之和(SSD 或 L2 范数),您可以做一个简单的技巧。由于 L2 范数的平方||m_i - v||^2(其中m_i是矩阵的第 i 行M,v是向量)等于点积<m_i - v, m_i-v>- 根据您获得的点积的线性度:因此您可以通过以下方式<m_i,m_i> - 2<m_i,v> + <v,v>计算向量中每一行的 SSD:计算一次每行的 L2 范数平方、一次每行与向量之间的点积以及一次向量的 L2 范数。这可以在 中完成。然而,对于 2 个矩阵之间的 SSD,您仍然会得到MvO(n^2)O(n^3)。不过,可以通过向量化操作而不是使用循环来进行改进。这是 2 个矩阵的简单实现:def mat_mat_l2_mult(mat,mat2):&nbsp; &nbsp; rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])&nbsp; &nbsp; cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)&nbsp; &nbsp; rows_cols_dot_product = mat @ mat2&nbsp; &nbsp; ssd = rows_norm -2*rows_cols_dot_product + cols_norm&nbsp; &nbsp; return ssd.sqrt()mat = torch.randn([20, 7])mat2 = torch.randn([7,20])print(mat_mat_l2_mult(mat, mat2))所得矩阵的每个单元格将具有中每行和每列之间i,j差异的 L2 范数。imatjmat2
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python