猿问

python:计算向量到矩阵每一行的欧氏距离的最快方法?

考虑这个 python 代码,我在其中尝试计算向量到矩阵每一行的欧几里距离。与我能找到的使用 Tullio.jl 的最佳 Julia 版本相比,它非常慢。


python 版本需要30s而 Julia 版本只需要75ms。


我确信我在 Python 方面没有做得最好。有更快的解决方案吗?欢迎使用 Numba 和 numpy 解决方案。


import numpy as np


# generate

a = np.random.rand(4000000, 128)


b = np.random.rand(128)


print(a.shape)

print(b.shape)




def lin_norm_ever(a, b):

    return np.apply_along_axis(lambda x: np.linalg.norm(x - b), 1, a)



import time

t = time.time()

res = lin_norm_ever(a, b)

print(res.shape)

elapsed = time.time() - t

print(elapsed)

朱莉娅版本


using Tullio

function comp_tullio(a, c)

    dist = zeros(Float32, size(a, 2))

    @tullio dist[i] = (c[j] - a[j,i])^2


    dist

end

@time comp_tullio(a, c)


@benchmark comp_tullio(a, c) # 75ms on my computer


冉冉说
浏览 115回答 1
1回答

慕沐林林

为了获得最佳性能,我将在此示例中使用 Numba。我还添加了来自 Divakars 链接答案的 2 种方法以进行比较。代码import numpy as npimport numba as nbfrom scipy.spatial.distance import cdist@nb.njit(fastmath=True,parallel=True,cache=True)def dist_1(mat,vec):    res=np.empty(mat.shape[0],dtype=mat.dtype)    for i in nb.prange(mat.shape[0]):        acc=0        for j in range(mat.shape[1]):            acc+=(mat[i,j]-vec[j])**2        res[i]=np.sqrt(acc)    return res#from https://stackoverflow.com/a/52364284/4045774def dist_2(mat,vec):    return cdist(mat, np.atleast_2d(vec)).ravel()#from https://stackoverflow.com/a/52364284/4045774def dist_3(mat,vec):    M = mat.dot(vec)    d = np.einsum('ij,ij->i',mat,mat) + np.inner(vec,vec) -2*M    return np.sqrt(d)时序#Float64a = np.random.rand(4000000, 128)b = np.random.rand(128)%timeit dist_1(a,b)#122 ms ± 3.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)%timeit dist_2(a,b)#484 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)%timeit dist_3(a,b)#432 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)#Float32a = np.random.rand(4000000, 128).astype(np.float32)b = np.random.rand(128).astype(np.float32)%timeit dist_1(a,b)#68.6 ms ± 414 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)%timeit dist_2(a,b)#2.2 s ± 32.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)#looks like there is a costly type-casting to float64%timeit dist_3(a,b)#228 ms ± 8.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
随时随地看视频慕课网APP

相关分类

Python
我要回答