我有以下代码,我试图使用并行循环,并且:numbafunctools.reduce()mul
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import jit, prange
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
gen = np.array([list(a) for a in product(flat, repeat=n)])
@jit(nopython=True, parallel=True)
def mtp(gen):
results = np.empty(gen.shape[0])
for i in prange(gen.shape[0]):
results[i] = reduce(mul, gen[i], initializer=None)
return results
mtp(gen)
但这给了我一个错误:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-503-cd6ef880fd4a> in <module>
10 results[i] = reduce(mul, gen[i], initializer=None)
11 return results
---> 12 mtp(gen)
~\Anaconda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
399 e.patch_message(msg)
400
--> 401 error_rewrite(e, 'typing')
402 except errors.UnsupportedError as e:
403 # Something unsupported is present in the user code, add help info
~\Anaconda3\lib\site-packages\numba\dispatcher.py in error_rewrite(e, issue_type)
342 raise e
343 else:
--> 344 reraise(type(e), e, None)
345
346 argtypes = []
~\Anaconda3\lib\site-packages\numba\six.py in reraise(tp, value, tb)
666 value = tp()
667 if value.__traceback__ is not tb:
--> 668 raise value.with_traceback(tb)
669 raise value
670
我不确定我哪里做错了。任何人都可以给我指出正确的方向吗?非常感谢。
Cats萌萌
相关分类