How to create a NumPy array based on the values of another NumPy array?

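I want to create a NumPy array whose element values depend on the values of the elements in another NumPy array. Currently, I have to use a for loop inside a list comprehension to iterate over array a to get b. What is the NumPy way to achieve this?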


Test script:


import numpy as np


def get_b( a ):

    b_dict = {  1:10., 2:20., 3:30. }

    return b_dict[ a ]


a = np.full( 10, 2 )

print( f'a = {a}' )

b = np.array( [get_b(i) for i in a] )

print( f'b = {b}' )

Output:


a = [2 2 2 2 2 2 2 2 2 2]

b = [20. 20. 20. 20. 20. 20. 20. 20. 20. 20.]


陪伴而非守候
4 answers

繁华开满天机

You can use np.vectorize to map the dictionary values onto the array:

In [6]: b_dict = {  1:10., 2:20., 3:30 }

In [7]: a = np.full( 10, 2 )

In [8]: np.vectorize(b_dict.get)(a)
Out[8]: array([20., 20., 20., 20., 20., 20., 20., 20., 20., 20.])
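Two details of np.vectorize are worth knowing here. It is a convenience wrapper that still calls the Python function once per element, so it is not a compiled loop. And if otypes is not given, the output dtype is inferred by calling the function on the first element. Also, b_dict.get returns None for a key that is not in the dict. A minimal sketch, assuming NaN is an acceptable value for missing keys (the name lookup is hypothetical):

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}

# Hypothetical wrapper: missing keys map to NaN instead of None, and
# otypes fixes the output dtype instead of inferring it from the first call.
lookup = np.vectorize(lambda k: b_dict.get(k, np.nan), otypes=[np.float64])

a = np.array([1, 2, 3, 4])
b = lookup(a)
print(b)  # [10. 20. 30. nan]
```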

慕运维8079593

Another way to solve the problem:

from operator import itemgetter
np.array(itemgetter(*a)(b_dict))

Output:

[20., 20., 20., 20., 20., 20., 20., 20., 20., 20.]

Comparison:

import numpy as np
from operator import itemgetter

#@kmundnic solution
def m1(a):
    def get_b(x):
        b_dict = {  1:10., 2:20., 3:30. }
        return b_dict[x]
    return np.fromiter(map(get_b, a), dtype=np.float64)

#@bigbounty solution
def m2(a):
    b_dict = {  1:10., 2:20., 3:30. }
    return np.vectorize(b_dict.get)(a)

#@Ehsan solution
def m3(a):
    b_dict = {  1:10., 2:20., 3:30. }
    return np.array(itemgetter(*a)(b_dict))

#@Sun Bear solution
def m4(a):
    def get_b( a ):
        b_dict = {  1:10., 2:20., 3:30. }
        return b_dict[ a ]
    return np.array( [get_b(i) for i in a] )

in_ = [np.full( n, 2 ) for n in [10,100,1000,10000]]

For a small dictionary, m2 appears to be the fastest on large inputs, while m3 is the fastest on small inputs.

For a larger dictionary:

b_dict = dict(zip(np.arange(100),np.arange(100)))
in_ = [np.full(n,50) for n in [10,100,1000,10000]]

m3 is the fastest method. You can choose based on the size of your dictionary and of your key array.
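One edge case in the itemgetter approach: with exactly one key, itemgetter returns the bare value rather than a tuple, so np.array(...) produces a 0-d array; with zero keys, itemgetter() raises TypeError. Separately, when the key array is large but contains few distinct values (as in np.full(n, 2)), one option is to look up each distinct key only once via np.unique(..., return_inverse=True). A sketch of both, assuming 1-D integer keys for the itemgetter version; the helper names are hypothetical:

```python
import numpy as np
from operator import itemgetter

b_dict = {1: 10., 2: 20., 3: 30.}


def lookup_itemgetter(a, mapping):
    """itemgetter lookup that stays 1-D even when a has a single element.

    Assumes a is 1-D and non-empty (itemgetter() with no keys raises TypeError).
    """
    # With exactly one key, itemgetter returns the bare value instead of a
    # tuple, so np.array would give a 0-d array; atleast_1d restores shape (1,).
    return np.atleast_1d(np.array(itemgetter(*a)(mapping)))


def lookup_unique(a, mapping):
    """Look up each distinct key once, then expand with the inverse index."""
    keys, inverse = np.unique(a, return_inverse=True)
    # One Python-level dict lookup per distinct key, not per element.
    values = np.array([mapping[k] for k in keys.tolist()], dtype=np.float64)
    # reshape keeps the input's shape across NumPy versions, which differ in
    # whether `inverse` comes back flattened or with a's shape.
    return values[inverse].reshape(np.shape(a))


a = np.array([3, 1, 2, 3])
print(lookup_itemgetter(np.array([3]), b_dict))  # [30.]
print(lookup_unique(a, b_dict))                  # [30. 10. 20. 30.]
```

np.unique sorts its input, so this trades per-element dict lookups for an O(n log n) sort; whether that pays off depends on your sizes, so it is worth timing against m2 and m3.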

摇曳的蔷薇

How about using map and np.fromiter?

import numpy as np

def get_b( a ):
    b_dict = {  1:10., 2:20., 3:30. }
    return b_dict[ a ]

b_dict = {  1:10., 2:20., 3:30. }  # module-level copy, used by the np.vectorize timing below

a = np.full( 10, 2 )
b = np.fromiter(map(get_b, a), dtype=np.float64)

Edit 1: a quick timing comparison:

%timeit np.array( [get_b(i) for i in a] )
5.58 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.fromiter(map(get_b, a), dtype=np.float64)
5.77 µs ± 177 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.vectorize(b_dict.get)(a)
12.9 µs ± 76.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Edit 2: it seems that example was too small:

a = np.full( 1000, 2 )

%timeit np.array( [get_b(i) for i in a] )
415 µs ± 9.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit np.fromiter(map(get_b, a), dtype=np.float64)
383 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit np.vectorize(b_dict.get)(a)
68.6 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
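Two small tweaks to the map/np.fromiter version. First, get_b from the question rebuilds b_dict on every call; mapping the dict's bound __getitem__ method avoids that and still raises KeyError for a missing key. Second, np.fromiter accepts a count argument, which lets it allocate the output once instead of growing it. A sketch under those assumptions (no timings measured here):

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}
a = np.full(10, 2)

# b_dict.__getitem__ is looked up once and reused for every element;
# like b_dict[k], it raises KeyError if an element of a is not a key.
# count=a.size lets np.fromiter pre-allocate the output array.
b = np.fromiter(map(b_dict.__getitem__, a), dtype=np.float64, count=a.size)
print(b)  # [20. 20. 20. 20. 20. 20. 20. 20. 20. 20.]
```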

catspeake

Does b_dict have to be a dictionary? If you have an array instead, e.g. ref = np.array([0, 10, 20, 30]), you can select values quickly by indexing: ref[a]. When working with NumPy, I try to avoid dicts. In my experience, NumPy indexing is anywhere from several times to several orders of magnitude faster than going through a Python dict. Below is a script that makes this comparison.

import numpy as np
from operator import itemgetter
import timeit
import matplotlib.pyplot as plt

#@kmundnic solution
def m1(a):
    def get_b(x):
        b = {  1:10., 2:20., 3:30. }
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[x]
    return np.fromiter(map(get_b, a), dtype=np.float64)

#@bigbounty solution
def m2(a):
    b = {  1:10., 2:20., 3:30. }
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.vectorize(b.get)(a)

#@Ehsan solution
def m3(a):
    b = {  1:10., 2:20., 3:30. }
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.array(itemgetter(*a)(b))

#@Sun Bear solution
def m4(a):
    def get_b( a ):
        b = {  1:10., 2:20., 3:30. }
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[ a ]
    return np.array( [get_b(i) for i in a] )

#@hpaulj solution
def m5(a):
    # index 0 is a placeholder because the dict keys start at 1
    b = np.array([0., 10., 20., 30.])
    #b = np.arange(0,1001,10)
    return b[a]

sizes = [10,100,1000,10000]
pm1 = []
pm2 = []
pm3 = []
pm4 = []
pm5 = []
for size in sizes:
    a = np.full( size, 2 )
    pm1.append( timeit.timeit( 'm1(a)', number=1000, globals=globals() ) )
    pm2.append( timeit.timeit( 'm2(a)', number=1000, globals=globals() ) )
    pm3.append( timeit.timeit( 'm3(a)', number=1000, globals=globals() ) )
    pm4.append( timeit.timeit( 'm4(a)', number=1000, globals=globals() ) )
    pm5.append( timeit.timeit( 'm5(a)', number=1000, globals=globals() ) )

print( 'm1 slower than m5 by :',np.array(pm1) / np.array(pm5) )
print( 'm2 slower than m5 by :',np.array(pm2) / np.array(pm5) )
print( 'm3 slower than m5 by :',np.array(pm3) / np.array(pm5) )
print( 'm4 slower than m5 by :',np.array(pm4) / np.array(pm5) )

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot( sizes, pm1, label='m1' )
ax.plot( sizes, pm2, label='m2' )
ax.plot( sizes, pm3, label='m3' )
ax.plot( sizes, pm4, label='m4' )
ax.plot( sizes, pm5, label='m5' )
ax.grid( which='both' )
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
ax.get_xaxis().set_label_text( label='len(a)', fontweight='bold' )
ax.get_yaxis().set_label_text( label='Runtime (sec)', fontweight='bold' )
plt.show()

Results:

len(b) = 3:
m1 slower than m5 by : [  4.22462367  29.79407905  85.03454097 339.2915358 ]
m2 slower than m5 by : [  8.64220685 11.57175871 13.76761749 46.1940683 ]
m3 slower than m5 by : [  3.25785432  21.63131578  54.71305704 220.15777696 ]
m4 slower than m5 by : [  4.60710166  30.93616607  91.8936744  371.00398273 ]

len(b) = 100:
m1 slower than m5 by : [  218.98603678  1976.50128737  9697.76615006 17742.79151719 ]
m2 slower than m5 by : [  41.76535891  53.85600913 109.35129345 164.13075291 ]
m3 slower than m5 by : [  24.82715462  36.77830986  87.56253196 141.04493237 ]
m4 slower than m5 by : [  222.04184193  2001.72120836  9775.22464369 18431.00155305 ]
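If the data arrives as a dict, it can be converted once into an array like ref. A minimal sketch, assuming all keys are non-negative integers and the largest key is small enough that an array of length max_key + 1 fits in memory; keys in that range that are absent from the dict become NaN. For sparse or very large keys, a sorted-keys variant using np.searchsorted avoids the big table. Both helper names are hypothetical:

```python
import numpy as np


def dict_to_lut(mapping, fill=np.nan):
    """Build an array lut with lut[k] == mapping[k] for every key k.

    Assumes non-negative integer keys; slots with no key get `fill`.
    """
    keys = np.fromiter(mapping.keys(), dtype=np.intp, count=len(mapping))
    vals = np.fromiter(mapping.values(), dtype=np.float64, count=len(mapping))
    lut = np.full(keys.max() + 1, fill, dtype=np.float64)
    lut[keys] = vals
    return lut


def lookup_sorted(a, mapping):
    """Map a through mapping with a binary search over the sorted keys.

    Raises KeyError if any element of a is not a key of mapping.
    """
    keys = np.array(sorted(mapping))
    vals = np.array([mapping[k] for k in keys.tolist()], dtype=np.float64)
    idx = np.searchsorted(keys, a)
    idx = np.minimum(idx, keys.size - 1)  # keep indices in bounds for the check
    if np.any(keys[idx] != a):
        raise KeyError("some elements of a are not keys of mapping")
    return vals[idx]


b_dict = {1: 10., 2: 20., 3: 30.}
a = np.full(10, 2)
lut = dict_to_lut(b_dict)  # lut is [nan, 10., 20., 30.]
print(lut[a])
print(lookup_sorted(a, b_dict))
```

Building the table is a one-time cost; after that, lut[a] is plain integer indexing, which is what m5 times above.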