我有一个非常大的 numpy 数组(包含多达一百万个元素),如下所示:
[0,1,6,5,1,2,7,6,2,3,8,7,3,4,9,8,5,6,11,10,6,7,12,11,7,
8,13,12,8,9,14,13,10,11,16,15,11,12,17,16,12,13,18,17,13,
14,19,18,15,16,21,20,16,17,22,21,17,18,23,22,18,19,24,23]
和一个小字典映射,用于替换上面数组中的一些元素
{4: 0, 9: 5, 14: 10, 19: 15, 20: 0, 21: 1, 22: 2, 23: 3, 24: 0}
我想根据上面的映射替换一些元素。numpy 数组非常大,只有一小部分元素(作为字典中的键出现)将被替换为相应的值。最快的方法是什么?
我相信还有更有效的方法,但现在,尝试一下
from numpy import copy
# Work on a copy; every mask below is computed against the untouched
# input, so a value written by one replacement is never remapped again.
newArray = copy(theArray)
# ``items()`` exists on both Python 2 and 3 (``iteritems()`` is Python 2 only).
for k, v in d.items():
    newArray[theArray == k] = v
微基准测试和正确性测试:
#!/usr/bin/env python2.7
from numpy import copy, random, arange
# Fixed seed so every run benchmarks the same input.
random.seed(0)
# 10**5 random integers drawn from [0, 30).
data = random.randint(30, size=10**5)
d = {4: 0, 9: 5, 14: 10, 19: 15, 20: 0, 21: 1, 22: 2, 23: 3, 24: 0}
# Materialise keys and values as lists: on Python 3 ``keys()``/``values()``
# return view objects, which NumPy cannot use as an index or value array.
dk = list(d.keys())
dv = list(d.values())
def f1(a, d):
    """Return a copy of ``a`` with each element found in ``d`` replaced.

    One boolean mask is built per dictionary entry. Masks are computed on
    the original ``a`` (not on the partially rewritten copy), so a value
    produced by one replacement is never replaced a second time.

    ``a`` is a NumPy array and is left unmodified; ``d`` maps old values
    to new values.
    """
    b = copy(a)
    # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2 only.
    for k, v in d.items():
        b[a == k] = v
    return b
def f2(a, d):
    """Replace, in place, each element of ``a`` that is a key of ``d``.

    Elements without an entry in ``d`` are written back unchanged.
    The input sequence itself is mutated and also returned.
    """
    # ``enumerate`` replaces the Python-2-only ``xrange(len(a))`` index loop.
    for i, value in enumerate(a):
        a[i] = d.get(value, value)
    return a
def f3(a, dk, dv):
    """Remap ``a`` through an integer lookup table.

    An identity table ``0..top`` is built, the entries at positions ``dk``
    are overwritten with ``dv``, and the table is then indexed by ``a``.

    ``a`` holds the integers to remap; ``dk`` and ``dv`` are the keys and
    their replacement values in matching order (lists or dict views).
    All values and keys are assumed to be non-negative integers.

    Returns a new array; ``a`` is not modified.
    """
    from numpy import asarray  # local import: not among the file-level names

    a = asarray(a)
    # Dict views (Python 3) cannot be used as a NumPy index; lists can.
    keys = list(dk)
    vals = list(dv)
    if a.size == 0:
        return a.copy()
    # ``a.max()`` runs in native code; builtin ``max(a)`` iterates in Python.
    top = int(a.max())
    if keys:
        # The table must also reach the largest key, even if absent from ``a``.
        top = max(top, max(keys))
    mp = arange(0, top + 1)
    if keys:
        mp[keys] = vals
    return mp[a]
# f2 rewrites its argument in place, so it gets a private copy and
# ``data`` stays untouched for the other two implementations.
a = copy(data)
res = f2(a, d)
# The f2 result serves as the reference: f1 and f3 must match it element-wise.
assert (f1(data, d) == res).all()
assert (f3(data, dk, dv) == res).all()
结果:
$ python2.7 -m timeit -s 'from w import f1,f3,data,d,dk,dv' 'f1(data,d)'
100 loops, best of 3: 6.15 msec per loop
$ python2.7 -m timeit -s 'from w import f1,f3,data,d,dk,dv' 'f3(data,dk,dv)'
100 loops, best of 3: 19.6 msec per loop
假设值在 0 和某个最大整数之间,可以通过使用 numpy 数组作为
int->int
字典来实现快速替换,如下所示
# Identity lookup table long enough for every value in ``data`` and every
# key of ``replace`` (a key larger than the data maximum would otherwise
# index past the end of the table).
mp = numpy.arange(0, max([numpy.max(data)] + list(replace)) + 1)
# list(...) is needed on Python 3, where keys()/values() are views that
# NumPy cannot use as an index array.
mp[list(replace.keys())] = list(replace.values())
data = mp[data]
其中初始数据为
data = [ 0 1 6 5 1 2 7 6 2 3 8 7 3 4 9 8 5 6 11 10 6 7 12 11 7
8 13 12 8 9 14 13 10 11 16 15 11 12 17 16 12 13 18 17 13 14 19 18 15 16
21 20 16 17 22 21 17 18 23 22 18 19 24 23]
并替换为
replace = {4: 0, 9: 5, 14: 10, 19: 15, 20: 0, 21: 1, 22: 2, 23: 3, 24: 0}
我们得到
data = [ 0 1 6 5 1 2 7 6 2 3 8 7 3 0 5 8 5 6 11 10 6 7 12 11 7
8 13 12 8 5 10 13 10 11 16 15 11 12 17 16 12 13 18 17 13 10 15 18 15 16
1 0 16 17 2 1 17 18 3 2 18 15 0 3]
我对一些解决方案进行了基准测试,结果非常明确:
import timeit
import numpy as np
# 300000 samples, each an even integer in the range 0..20000.
array = 2 * np.round(np.random.uniform(0,10000,300000)).astype(int)
from_values = np.unique(array) # sorted distinct even values that occur in ``array``
to_values = np.arange(from_values.size) # 0 .. from_values.size - 1; equals from_values/2 only if every even value was drawn
# Plain dict view of the same mapping.
d = dict(zip(from_values, to_values))
def method_for_loop():
    """Remap the module-level ``array`` with one boolean mask per value pair.

    Reads the globals ``array``, ``from_values`` and ``to_values`` and
    prints whether the result equals ``array/2``.
    """
    out = array.copy()
    for from_value, to_value in zip(from_values, to_values):
        # Mask against the untouched input rather than ``out``: otherwise a
        # value written by an earlier pair could be remapped again by a
        # later pair whose ``from_value`` happens to equal it.
        out[array == from_value] = to_value
    print('Check method_for_loop :', np.all(out == array/2)) # Just checking


print('Time method_for_loop :', timeit.timeit(method_for_loop, number = 1))
def method_list_comprehension():
    """Remap the module-level ``array`` with one dict lookup per element.

    Reads the globals ``array`` and ``d``; the result is a plain Python
    list. Prints whether it equals ``array/2``.
    """
    # Same per-element lookups as ``[d[i] for i in array]``.
    out = list(map(d.__getitem__, array))
    is_correct = np.all(out == array/2)  # Just checking
    print('Check method_list_comprehension :', is_correct)


print('Time method_list_comprehension :', timeit.timeit(method_list_comprehension, number=1))
def method_bruteforce():
    """Remap the module-level ``array`` via a full pairwise comparison.

    Every element of ``array`` is compared with every entry of
    ``from_values`` (a len(array) x len(from_values) boolean matrix); the
    column index of each match selects the replacement in ``to_values``.
    Prints whether the result equals ``array/2``.
    """
    # Broadcast a column of samples against the row of candidate values.
    matches = array[:, None] == from_values
    _, idx = np.nonzero(matches)
    out = to_values[idx]
    is_correct = np.all(out == array/2)  # Just checking
    print('Check method_bruteforce :', is_correct)


print('Time method_bruteforce :', timeit.timeit(method_bruteforce, number=1))
def method_searchsort():
    """Remap the module-level ``array`` with a binary search per element.

    ``from_values`` is searched through its argsort permutation, and the
    positions found are mapped back through that permutation to pick the
    entries of ``to_values``. Prints whether the result equals ``array/2``.
    """
    order = np.argsort(from_values)
    positions = np.searchsorted(from_values, array, sorter=order)
    # Indexing the permutation first selects the same elements as
    # ``to_values[order][positions]``.
    out = to_values[order[positions]]
    is_correct = np.all(out == array/2)  # Just checking
    print('Check method_searchsort :', is_correct)


print('Time method_searchsort :', timeit.timeit(method_searchsort, number=1))
我得到了以下结果:
Check method_for_loop : True
Time method_for_loop : 2.6411612760275602
Check method_list_comprehension : True
Time method_list_comprehension : 0.07994363596662879
Check method_bruteforce : True
Time method_bruteforce : 11.960559037979692
Check method_searchsort : True
Time method_searchsort : 0.03770717792212963
按上面列出的计时结果,“searchsort”方法比“for”循环快约 70 倍,比 numpy 暴力方法快约 320 倍。列表推导式方法也是代码简单性和速度之间非常好的权衡。
实现此目的的另一种更通用的方法是函数向量化:
import numpy as np
data = np.array([
    0, 1, 6, 5, 1, 2, 7, 6, 2, 3, 8, 7, 3, 4, 9, 8,
    5, 6, 11, 10, 6, 7, 12, 11, 7, 8, 13, 12, 8, 9, 14, 13,
    10, 11, 16, 15, 11, 12, 17, 16, 12, 13, 18, 17, 13, 14, 19, 18,
    15, 16, 21, 20, 16, 17, 22, 21, 17, 18, 23, 22, 18, 19, 24, 23,
])
mapper_dict = {4: 0, 9: 5, 14: 10, 19: 15, 20: 0, 21: 1, 22: 2, 23: 3, 24: 0}


def mp(entry):
    """Return the replacement for ``entry``, or ``entry`` itself if unmapped."""
    # A single ``get`` replaces the membership test plus second lookup.
    return mapper_dict.get(entry, entry)


# Wrap the scalar function so it can be applied element-wise to an array.
mp = np.vectorize(mp)
# ``print(x)`` with one argument behaves the same on Python 2 and 3;
# the statement form ``print x`` is a syntax error on Python 3.
print(mp(data))
numpy_indexed包(免责声明:我是它的作者)为此类问题提供了一个优雅且高效的矢量化解决方案:
import numpy_indexed as npi
# Passes the array plus the mapping's keys and values (as two lists) to npi.remap.
# NOTE(review): ``dict`` here appears to stand for the user's replacement
# mapping, not the builtin type (``dict.keys()`` on the builtin would fail);
# presumably it should be bound under a non-shadowing name — confirm.
remapped_array = npi.remap(theArray, list(dict.keys()), list(dict.values()))
实现的方法类似于 Jean Lescut 提到的基于搜索排序的方法,但更通用。例如,数组的项不需要是整数,而是可以是任何类型,甚至是 nd 子数组本身;但它应该达到同样的性能。
我没有找到不在数组上使用 Python 循环的解决方案(Celil 的方案除外,但它假设数值“较小”),所以这里给出一个替代方案:
def replace(arr, rep_dict):
    """Return ``arr`` with every element that is a key of ``rep_dict`` replaced.

    Elements that are not keys are returned unchanged (previously they were
    silently mapped to a neighbouring key's value, or raised ``IndexError``
    when larger than every key). When all elements are keys the result is
    identical to the earlier behaviour.

    ``arr`` is array-like with numeric values; ``rep_dict`` maps old values
    to new values and has numeric, sortable keys.

    Returns a new NumPy array; ``arr`` is not modified.
    """
    import numpy as np  # local import keeps the snippet self-contained

    arr = np.asarray(arr)
    if not rep_dict:
        # Nothing to replace (the unpacking below would fail on an empty dict).
        return arr.copy()
    # Removing the explicit "list" breaks python3
    rep_keys, rep_vals = np.array(list(zip(*sorted(rep_dict.items()))))
    idces = np.digitize(arr, rep_keys, right=True)
    # Values above the largest key yield len(rep_keys); clamp so the
    # lookups below stay in range.
    idces = np.minimum(idces, len(rep_keys) - 1)
    mapped = rep_vals[idces]
    # For a real key, rep_keys[digitize(arr, rep_keys, right=True)] == arr
    found = rep_keys[idces] == arr
    if found.all():
        return mapped
    return np.where(found, mapped, arr)
“idces”的创建方式来自这里。
使用
np.in1d
和 np.searchsorted
的完全矢量化解决方案:
replace = numpy.array([list(replace.keys()), list(replace.values())])  # Create 2D replacement matrix (row 0: keys, row 1: values)
# searchsorted requires its first argument in ascending order, and dict
# iteration order is not sorted by key, so order the columns by key first.
replace = replace[:, replace[0, :].argsort()]
mask = numpy.in1d(data, replace[0, :])  # Find elements that need replacement
data[mask] = replace[1, numpy.searchsorted(replace[0, :], data[mask])]  # Replace elements
# Replace each element in place; elements without an entry in ``the_dict``
# are written back unchanged. ``enumerate`` replaces the Python-2-only
# ``xrange(len(...))`` index loop.
for i, item in enumerate(the_array):
    the_array[i] = the_dict.get(item, item)
好吧,您需要遍历一次
theArray
,并且对于每个元素,如果它在字典中,则将其替换。
# Walk the array once and overwrite only the elements that are keys.
# The loop previously iterated over ``theArray`` but indexed ``foo``; it
# now uses ``theArray`` throughout. ``enumerate`` replaces the
# Python-2-only ``xrange``.
# NOTE(review): ``dict`` presumably stands for the user's mapping and
# shadows the builtin — confirm and rename in real code.
for i, item in enumerate(theArray):
    if item in dict:
        theArray[i] = dict[item]
Pythonic 方式,不需要数据是整数,甚至可以是字符串:
from scipy.stats import rankdata
import numpy as np
data = np.random.rand(100000)
# Every key is taken from ``data``; the assignment into ``mp`` below relies
# on each key actually occurring in ``data``.
replace = {data[0]: 1, data[5]: 8, data[8]: 10}
# Two-column table (key, value). list(...) is needed on Python 3, where
# keys()/values() are views that vstack cannot stack.
arr = np.vstack((list(replace.keys()), list(replace.values()))).transpose()
# Order rows by KEY: ``unique`` below is sorted ascending, so the masked
# assignment pairs the i-th smallest key with the i-th row of ``arr``.
arr = arr[arr[:, 0].argsort()]
unique = np.unique(data)
# Identity table: column 0 is the original value, column 1 its replacement.
mp = np.vstack((unique, unique)).transpose()
# Match against the key column only, so a replacement value that also
# occurs in ``data`` is not mistaken for a key.
mp[np.isin(mp[:, 0], arr[:, 0]), 1] = arr[:, 1]
# Dense ranks are 1-based positions into ``unique``; cast to int so they
# are valid indices whatever dtype rankdata returns.
data = mp[(rankdata(data, 'dense') - 1).astype(int)][:, 1]
np.unique
:使用
np.unique
将数组折叠成更小的 values
。
def np_remap(arr, d):
    """Return an array shaped like ``arr`` with values remapped through ``d``.

    The dictionary is consulted once per distinct value in ``arr``, not
    once per element. Values with no entry in ``d`` are kept unchanged
    (previously they raised ``KeyError``).

    ``arr`` is a NumPy array of any shape; ``d`` maps old values to new ones.
    """
    values, inverse = np.unique(arr, return_inverse=True)
    # One lookup per distinct value; unmapped values map to themselves.
    values = np.array([d.get(x, x) for x in values])
    # ``inverse`` gives, per element, its index into the distinct values.
    return values[inverse].reshape(arr.shape)
建议在以下情况下使用:
max(d.keys())
很大。len(d)
比 max(d.keys())
小得多。np.searchsorted
方法的时间复杂度更低。测试:
>>> d = {111: 1010, 222: 2020, 333: 3030}
>>> np_remap(np.array([333, 111, 111, 222, 333, 111]).reshape(-1, 2), d)
array([[3030, 1010],
[1010, 2020],
[3030, 1010]])