Numba 和多维添加 - 不适用于 numpy.newaxis?

尝试在 python 上加速 DP 算法,numba 似乎是一个合适的候选者。


我正在用提供 3D 数组的 1D 数组减去 2D 数组。然后我使用.argmin()第三维来获得一个二维数组。这适用于 numpy,但不适用于 numba。


重现问题的玩具代码:


from numba import jit

import numpy as np


inflow      = np.arange(1,0,-0.01)                  # Dim [T]

actions     = np.arange(0,1,0.05)                   # Dim [M]

start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]

disc_lvl    = np.arange(0,1000)                     # Dim [O]


@jit(nopython=True)

def my_func(disc_lvl, actions, start_lvl, inflow):

    for i in range(0,100):

        # Calculate new level at time i

        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]


        # For each new_level element, find closest discretized level

        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]

        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]


        return True


# function works fine without numba

success = my_func(disc_lvl, actions, start_lvl, inflow)

为什么上面的代码不运行?取出时会这样@jit(nopython=True)。是否有一个工作回合可以使以下计算与 numba 一起工作?


我尝试了带有 numpy repeats 和 expand_dims 的变体,以及明确定义 jit 函数的输入类型但没有成功。


慕斯王
浏览 95回答 2
2回答

HUX布斯

您需要进行一些更改才能使其正常工作:使用 : 为 Numba 添加维度arr[:, :, None],看起来getitem更喜欢使用reshape使用np.abs而不是内置absargminwithaxis关键字参数未实现。更喜欢使用 Numba 旨在优化的循环。修复所有这些后,您可以运行 jited 函数:from numba import jitimport numpy as npinflow = np.arange(1,0,-0.01)  # Dim [T]actions = np.arange(0,1,0.05)  # Dim [M]start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]disc_lvl = np.arange(0,1000)  # Dim [O]@jit(nopython=True)def my_func(disc_lvl, actions, start_lvl, inflow):    for i in range(0,100):        # Calculate new level at time i        new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]        # For each new_level element, find closest discretized level        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]        idx_lvl = np.empty(new_lvl.shape)        for i in range(diff.shape[0]):            for j in range(diff.shape[1]):                idx_lvl[i, j] = diff[i, j, :].argmin()        return True# function works fine without numbasuccess = my_func(disc_lvl, actions, start_lvl, inflow)

翻过高山走不出你

在我的第一篇文章的更正代码下方找到,您可以在使用和不使用 numba 库的 jitted 模式的情况下执行(通过删除以 @jit 开头的行)。我观察到这个例子的速度增加了 2 倍。from numba import jitimport numpy as npimport datetime as dtinflow = np.arange(1,0,-0.01)                       # Dim [T]nbTime = np.shape(inflow)[0]actions = np.arange(0,1,0.01)                       # Dim [M]start_lvl = np.random.rand(500).reshape(-1,1)*49    # Dim [Nx1]disc_lvl = np.arange(0,1000)                        # Dim [O]@jit(nopython=True)def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):    # Initialize result     res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))    for t in range(0,nbTime):        # Calculate new level at time t        new_lvl = start_lvl + inflow[t] + actions  # Dim [N x M]              print(t)        # For each new_level element, find closest discretized level        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]        idx_lvl = np.empty(new_lvl.shape)        for i in range(diff.shape[0]):            for j in range(diff.shape[1]):                idx_lvl[i, j] = diff[i, j, :].argmin()        res[t,:,:] = idx_lvl    return res# Call function and print running timestart_time = dt.datetime.now()result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)print('Execution time :',(dt.datetime.now() - start_time))
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python