
Fastest Python log-sum-exp with "reduceat"

As part of a statistical programming package, I need to add together log-transformed values with the LogSumExp function. This is considerably less efficient than simply summing the unlogged values.

Furthermore, I need to add the values together using the numpy.ufunc.reduceat functionality.
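To illustrate what reduceat does (a minimal sketch with made-up values, not from the original post): output element i reduces arr[indices[i]:indices[i+1]], and the last segment runs to the end of the array. In log space, np.logaddexp.reduceat over logged values matches np.add.reduceat over the raw values:

import numpy as np

vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
idx = np.array([0, 2, 4])                         # segments [0:2], [2:4], [4:]
print(np.add.reduceat(vals, idx))                 # [3. 7. 5.]
print(np.logaddexp.reduceat(np.log(vals), idx))   # ~[log 3, log 7, log 5]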

I have considered several options, with code below:

  1. (For comparison in non-log space) use numpy.add.reduceat

  2. Numpy's ufunc for adding logged values: np.logaddexp.reduceat

  3. A handwritten reduceat function with the logsumexp functions below:

import numpy as np
import numba


def logsumexp_reduceat(arr, indices, logsum_exp_func):
    # apply logsum_exp_func to each segment arr[indices[i]:indices[i+1]],
    # plus the tail segment from indices[-1] to the end of the array
    res = list()
    i_start = indices[0]
    for i in indices[1:]:
        res.append(logsum_exp_func(arr[i_start:i]))
        i_start = i
    res.append(logsum_exp_func(arr[i_start:]))
    return res


@numba.jit(nopython=True)
def logsumexp(X):
    # naive log-sum-exp: overflows when the inputs are large
    r = 0.0
    for x in X:
        r += np.exp(x)
    return np.log(r)


@numba.jit(nopython=True)
def logsumexp_stream(X):
    # streaming log-sum-exp (Sebastian Nowozin): maintains
    # r = sum(exp(x - alpha)) with alpha the running maximum,
    # rescaling r whenever a new maximum is encountered
    alpha = -np.inf
    r = 0.0
    for x in X:
        if x != -np.inf:
            if x <= alpha:
                r += np.exp(x - alpha)
            else:
                r *= np.exp(alpha - x)
                r += 1.0
                alpha = x
    return np.log(r) + alpha
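The point of the streaming version is numerical safety: every exponent is shifted by the running maximum alpha, so it cannot overflow. A quick check with large made-up inputs (my own illustration, not part of the original benchmark):

import scipy.special

x_big = np.array([1000.0, 1000.5, 999.0])
print(logsumexp(x_big))                  # inf: np.exp(1000.0) overflows
print(logsumexp_stream(x_big))           # ~1001.104
print(scipy.special.logsumexp(x_big))    # ~1001.104, agrees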


arr = np.random.uniform(0, 0.1, 10000)
log_arr = np.log(arr)
indices = sorted(np.random.randint(0, 10000, 100))


# approach 1
%timeit np.add.reduceat(arr, indices)
12.7 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

# approach 2
%timeit np.logaddexp.reduceat(log_arr, indices)
462 µs ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# approach 3, scipy function
%timeit logsumexp_reduceat(log_arr, indices, scipy.special.logsumexp)
3.69 ms ± 273 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# approach 3, handwritten logsumexp
%timeit logsumexp_reduceat(log_arr, indices, logsumexp)
139 µs ± 7.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


The timeit results show that the handwritten logsumexp function with numba is the fastest option, but it is still roughly 10x slower than numpy.add.reduceat.

A couple of questions:

  1. Are there any other faster approaches (or tweaks to the options I presented)? For instance, is there a way to compute the logsumexp function using a lookup table?

  2. Why is Sebastian Nowozin's "streaming logsumexp" function not faster than the naive approach?


Asked by 杨魅力 · 234 views

1 Answer

largeQ

There is some room for improvement, but never expect logsumexp to be as fast as a standard summation, because exp is a fairly expensive operation.

Example:

import numpy as np

# from version 0.43 until 0.47 this has to be set before importing numba
# Bug: https://github.com/numba/numba/issues/4689
from llvmlite import binding
binding.set_option('SVML', '-vector-library=SVML')
import numba as nb


@nb.njit(fastmath=True, parallel=False)
def logsum_exp_reduceat(arr, indices):
    res = np.empty(indices.shape[0], dtype=arr.dtype)
    for i in nb.prange(indices.shape[0] - 1):
        r = 0.
        for j in range(indices[i], indices[i + 1]):
            r += np.exp(arr[j])
        res[i] = np.log(r)
    # tail segment from the last index to the end of the array
    r = 0.
    for j in range(indices[-1], arr.shape[0]):
        r += np.exp(arr[j])
    res[-1] = np.log(r)
    return res

Timings:

# small example where parallelization doesn't make sense
arr = np.random.uniform(0, 0.1, 10_000)
log_arr = np.log(arr)
# use arrays if possible
indices = np.sort(np.random.randint(0, 10_000, 100))

%timeit logsum_exp_reduceat(arr, indices)
# without parallelization: 22 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# with parallelization:    84.7 µs ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.add.reduceat(arr, indices)
# 4.46 µs ± 61.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

# large example where parallelization makes sense
arr = np.random.uniform(0, 0.1, 1_000_000)
log_arr = np.log(arr)
indices = np.sort(np.random.randint(0, 1_000_000, 100))

%timeit logsum_exp_reduceat(arr, indices)
# without parallelization: 1.57 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# with parallelization:    409 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.add.reduceat(arr, indices)
# 340 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
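As a quick sanity check that this segment reduction matches scipy's logsumexp per segment (my own sketch, not part of the original answer):

import numpy as np
import scipy.special

log_arr = np.log(np.random.uniform(0, 0.1, 1_000))
indices = np.array([0, 100, 300, 700])
bounds = np.append(indices, len(log_arr))
expected = [scipy.special.logsumexp(log_arr[s:e]) for s, e in zip(bounds[:-1], bounds[1:])]
np.testing.assert_allclose(logsum_exp_reduceat(log_arr, indices), expected)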
