我有一个非常大的 numpy 数组,其中包含以下条目:
[['0/1' '2/0']
['3/0' '1/4']]
我想转换它/获取一个带有 3d 数组的数组
[[[0 1] [2 0]]
[[3 0] [1 4]]]
数组很宽,所以列很多,但行不多。该字符串大约有 100 种可能性。
我尝试过numba:
import numpy as np
import itertools
from numba import njit
import time
@njit(nopython=True)
def index_with_numba(data,int_data,indices):
for pos in indices:
str_match = str(pos[0])+'/'+str(pos[1])
for i in range(data.shape[0]):
for j in range(data.shape[1]):
if data[i, j] == str_match:
int_data[i,j] = pos
return int_data
def generate_masks():
masks=[]
def _2d_array(i,j):
return np.asarray([i,j],dtype=np.int32)
for i in range(10):
for j in range(10):
masks.append(_2d_array(i,j))
return masks
rows = 100000
cols = 200
numerators = np.random.randint(0, 10, size=(rows,cols))
denominators = np.random.randint(1, 10, size=(rows,cols))
samples = np.array([f"{numerator}/{denominator}" for numerator, denominator in zip(numerators.flatten(), denominators.flatten())],dtype=str).reshape(rows, cols)
samples_int = np.empty((samples.shape[0],samples.shape[1],2),dtype=np.int32)
#get the possibilities:
numerators = list(range(10))
denominators = list(range(10))
# Generate all possible masks
masks = generate_masks()
t0=time.time()
samples_int = index_with_numba(samples,samples_int, masks)
t1=time.time()
print(f"Time to index {t1-t0}")
但是太慢了,不可行。
Time to index 182.0304057598114
我想要这个的原因是我想编写一个 cuda 内核来根据原始值执行操作 - 所以对于“0/1”我需要 0 和 1 等,但我无法处理字符串。我本来以为可以用口罩,但好像不太合适。
任何建议表示赞赏。
由于您的整数都是个位数,因此您可以将输入数组视为
'U1'
数组:
arr_s = np.array([['0/1', '2/0'],
['3/0', '1/4']])
arr_u1 = arr_s.view('U1')
# array([['0', '/', '1', '2', '/', '0'],
# ['3', '/', '0', '1', '/', '4']], dtype='<U1')
现在,您已经知道字符串中数字的索引: 预期结果的第
[:, :, 0]
元素位于 arr_u1[:, ::3]
中,预期结果的 [:, :, 1]
元素位于 arr_u1[:, 2::3]
中
nrows, ncols = arr_s.shape
result = np.zeros((nrows, ncols, 2), dtype=int)
result[:, :, 0] = arr_u1[:, ::3].astype(int)
result[:, :, 1] = arr_u1[:, 2::3].astype(int)
这会给你预期的结果:
array([[[0, 1],
[2, 0]],
[[3, 0],
[1, 4]]])
Numba 几乎不支持 Unicode 字符串,并且目前速度非常(非常)慢。只需在调用 Numba 函数之前将 unicode 字符串转换为字节字符串,就可以使 Numba 代码速度提高多个数量级。不需要使用 GPU 来完成如此简单的任务。请注意,即使对于 Numpy 来说,转换也相当慢(这并不是很有效,因为它是为 numerical 计算而设计的,因此得名)。如果可以的话,请考虑直接将字符串数组生成为字节字符串数组,因为这里不需要 unicode。
这是一个例子:
# ~6 seconds on my machine (i5-9600KF CPU)
bsamples = samples.astype('B3')
# 1 ms instead of >180_000 ms
index_with_numba(bsamples, samples_int, masks)
这两行代码比初始代码快30 倍以上。 >99%的时间都花在第一行。因此,直接计算字节串应该比初始代码快 180 000 倍以上!