假设我们有 2 个矩阵:
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape
最简单的常用矩阵乘法如下所示:
def mat_vec_dot_product(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += mat[i][j] * vect[j]
return res
res = torch.zeros([n, n])
for k in range(n):
res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
但是如果我需要应用 L2 范数而不是点积怎么办?代码如下:
def mat_vec_l2_mult(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += (mat[i][j] - vect[j]) ** 2
res = res.sqrt()
return res
for k in range(n):
res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])
我们可以使用 Torch 或任何其他库以最佳方式做到这一点吗?因为简单的 O(n^3) Python 代码运行速度非常慢。
慕虎7371278
翻翻过去那场雪
相关分类