如何将numba与functools.reduce()一起使用

我有以下代码,我试图使用并行循环,并且:numbafunctools.reduce()mul


import numpy as np

from itertools import product

from functools import reduce

from operator import mul

from numba import jit, prange


lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

arr = np.array(lst)

n = 3

flat = np.ravel(arr).tolist()

gen = np.array([list(a) for a in product(flat, repeat=n)])


@jit(nopython=True, parallel=True)

def mtp(gen):

    results = np.empty(gen.shape[0])

    for i in prange(gen.shape[0]):

        results[i] = reduce(mul, gen[i], initializer=None)

    return results

mtp(gen)

但这给了我一个错误:


---------------------------------------------------------------------------

TypingError                               Traceback (most recent call last)

<ipython-input-503-cd6ef880fd4a> in <module>

     10         results[i] = reduce(mul, gen[i], initializer=None)

     11     return results

---> 12 mtp(gen)


~\Anaconda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)

    399                 e.patch_message(msg)

    400 

--> 401             error_rewrite(e, 'typing')

    402         except errors.UnsupportedError as e:

    403             # Something unsupported is present in the user code, add help info


~\Anaconda3\lib\site-packages\numba\dispatcher.py in error_rewrite(e, issue_type)

    342                 raise e

    343             else:

--> 344                 reraise(type(e), e, None)

    345 

    346         argtypes = []


~\Anaconda3\lib\site-packages\numba\six.py in reraise(tp, value, tb)

    666             value = tp()

    667         if value.__traceback__ is not tb:

--> 668             raise value.with_traceback(tb)

    669         raise value

    670 


我不确定我哪里做错了。任何人都可以给我指出正确的方向吗?非常感谢。


慕丝7291255
浏览 82回答 1
1回答

Cats萌萌

您可以在 numba jitted 函数中使用 np.prod:n = 3lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]arr = np.array(lst)flat = np.ravel(arr).tolist()gen = [list(a) for a in product(flat, repeat=n)]@jit(nopython=True, parallel=True)def mtp(gen):&nbsp; &nbsp; results = np.empty(len(gen))&nbsp; &nbsp; for i in prange(len(gen)):&nbsp; &nbsp; &nbsp; &nbsp; results[i] = np.prod(gen[i])&nbsp; &nbsp; return results或者,您可以使用如下所示的reduce(感谢@stuartarchibald指出这一点),尽管并行化在下面不起作用(至少从numba 0.48开始):import numpy as npfrom itertools import productfrom functools import reducefrom operator import mulfrom numba import njit, prangelst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]arr = np.array(lst)n = 3flat = np.ravel(arr).tolist()gen = np.array([list(a) for a in product(flat, repeat=n)])@njitdef mul_wrapper(x, y):&nbsp; &nbsp; return mul(x, y)@njitdef mtp(gen):&nbsp; &nbsp; results = np.empty(gen.shape[0])&nbsp; &nbsp; for i in prange(gen.shape[0]):&nbsp; &nbsp; &nbsp; &nbsp; results[i] = reduce(mul_wrapper, gen[i], None)&nbsp; &nbsp; return resultsprint(mtp(gen))或者,因为Numba内部有一点魔力,可以发现将转义函数并编译它们的闭包。(再次感谢@stuartarchibald),你可以这样,在下面:@njitdef mtp(gen):&nbsp; &nbsp; results = np.empty(gen.shape[0])&nbsp; &nbsp; def op(x, y):&nbsp; &nbsp; &nbsp; &nbsp; return mul(x, y)&nbsp; &nbsp; for i in prange(gen.shape[0]):&nbsp; &nbsp; &nbsp; &nbsp; results[i] = reduce(op, gen[i], None)&nbsp; &nbsp; return results但同样,并行在numba 0.48之前在这里不起作用。请注意,核心开发团队成员推荐的方法是采用第一个使用 .它可以与并行标志一起使用,并具有更直接的实现。np.prod
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python