我正在分析一些代码,发现结果令我感到惊讶np.where()。我想where()在数组的一部分上使用(知道2D数组的很大一部分与我的搜索无关),并发现它是我代码中的瓶颈。作为测试,我创建了一个新的2D数组作为该切片的副本,并测试了的速度where()。事实证明,它的运行速度明显更快。在我的实际情况中,速度的提高确实非常显着,但是我认为此测试代码仍然可以证明我的发现:
import numpy as np
def where_on_view(arr):
new_arr = np.where(arr[:, 25:75] == 5, arr[:, 25:75], np.NaN)
def where_on_copy(arr):
copied_arr = arr[:, 25:75].copy()
new_arr = np.where(copied_arr == 5, copied_arr, np.NaN)
arr = np.random.choice(np.arange(10), 1000000).reshape(1000, 1000)
而timeit结果:
%timeit where_on_view(arr)
398 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit where_on_copy(arr)
295 µs ± 6.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
由于这两种方法都返回一个新数组,因此我不清楚如何事先获取切片的完整副本才能达到np.where()这种程度。我还进行了一些健全性检查,以确认:
在这种情况下,它们都返回相同的结果。
where() 搜索实际上仅限于切片,而不是检查整个数组,然后过滤输出。
这里:
# Sanity check that they do give the same output
test_arr = np.random.choice(np.arange(3), 25).reshape(5, 5)
test_arr_copy = test_arr[:, 1:3].copy()
print("No copy")
print(np.where(test_arr[:, 1:3] == 2, test_arr[:, 1:3], np.NaN))
print("With copy")
print(np.where(test_arr_copy == 2, test_arr_copy, np.NaN))
# Sanity check that it doesn't search the whole array
def where_on_full_array(arr):
new_arr = np.where(arr == 5, arr, np.NaN)
#%timeit where_on_full_array(arr)
#7.54 ms ± 47.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
我很好奇这种情况下增加的开销来自哪里?
婷婷同学_
相关分类