I am trying to find the positions, within a large list, of all the elements of a sublist drawn from that list.
In Python, using numpy, suppose I have
from datetime import datetime as dt
import numpy as np
from numba import jit, int64
n, N = 20, 120000
int_vocabulary = np.array(range(N))
np.random.shuffle(int_vocabulary) # to make the problem non-trivial
int_sequence = np.random.choice(int_vocabulary, n, replace=False)
I want to find the positions, in the large array int_vocabulary, of all the integers contained in int_sequence. I am interested in a fast computation.
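To make the goal concrete, here is a toy illustration with made-up values (the names vocab and seq and the numbers are hypothetical, chosen only for this sketch). It shows what "positions" means, not a fast way to compute them.

```python
import numpy as np

# Toy data (hypothetical values, far smaller than the real problem)
vocab = np.array([7, 3, 9, 1, 5])
seq = np.array([9, 7])

# Position of each element of seq inside vocab, in the order of seq.
# Assumes every element of seq occurs exactly once in vocab.
positions = np.array([int(np.flatnonzero(vocab == x)[0]) for x in seq])
print(positions)  # [2 0]
```

Note that the masking approaches further down return the same positions sorted in increasing order rather than in the order of the query sequence.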
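So far I have tried a brute-force search with numba, a numpy masking approach, a brute-force list comprehension (as a baseline), and a mix of list comprehension and numpy masking.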
@jit(int64[:](int64[:], int64[:], int64, int64))
def check(int_sequence, int_vocabulary, n, N):
    all_indices = np.full(n, N)
    for xi in range(n):
        for i in range(N):
            if int_sequence[xi] == int_vocabulary[i]:
                all_indices[xi] = i
    return all_indices
t0 = dt.now()
for _ in range(10):
    all_indices0 = check(int_sequence, int_vocabulary, n, N)
t0 = (dt.now() - t0).total_seconds()
print("numba : ", t0)
t0 = dt.now()
for _ in range(10):
    mask = np.full(len(int_vocabulary), False)
    for x in int_sequence:
        mask += int_vocabulary == x
    all_indices1 = np.flatnonzero(mask)
t0 = (dt.now() - t0).total_seconds()
print("numpy :", t0)
t0 = dt.now()
for _ in range(10):
    all_indices2 = np.array([i for i, x in enumerate(int_vocabulary)
                             if x in int_sequence])
t0 = (dt.now() - t0).total_seconds()
print("list comprehension : ", t0)
t0 = dt.now()
for _ in range(10):
    mask = np.sum(np.array([int_vocabulary == x for x in int_sequence]), axis=0)
    all_indices3 = np.flatnonzero(mask)
t0 = (dt.now() - t0).total_seconds()
print("mixed numpy + list comprehension : ", t0)
assert np.sum(all_indices0) == np.sum(all_indices1)
assert np.sum(all_indices1) == np.sum(all_indices2)
assert np.sum(all_indices2) == np.sum(all_indices3)
Each computation is repeated 10 times to get comparable timings. The results (total seconds for the 10 runs) are
numba : 0.028039
numpy : 0.011616
list comprehension : 3.116753
mixed numpy + list comprehension : 0.032301
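The 10-repetition timing could also be done with the standard timeit module, which uses a high-resolution clock and makes it easy to repeat the whole measurement. The sketch below is only an illustration: the helper numpy_mask and the seeded generator are assumptions introduced here, and it re-times just the masking variant.

```python
import timeit
import numpy as np

# Same problem size as in the question; the seed is an arbitrary choice
rng = np.random.default_rng(0)
vocab = rng.permutation(120_000)
seq = rng.choice(vocab, 20, replace=False)

def numpy_mask(seq, vocab):
    """Masking approach: OR together one boolean mask per query value."""
    mask = np.zeros(len(vocab), dtype=bool)
    for x in seq:
        mask |= vocab == x
    return np.flatnonzero(mask)

# 5 measurements, each timing 10 calls; the minimum is the least noisy figure
times = timeit.repeat(lambda: numpy_mask(seq, vocab), number=10, repeat=5)
best = min(times)
print(f"numpy mask: best of 5 x 10 calls = {best:.6f} s")
```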
I am still wondering whether there is a faster algorithm for this problem.
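One candidate that might be worth benchmarking is a sort-based lookup using np.argsort and np.searchsorted. This is a sketch only, not timed here, and it assumes every queried value occurs exactly once in the vocabulary; the names vocab, seq, order and positions are introduced for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = rng.permutation(120_000)
seq = rng.choice(vocab, 20, replace=False)

# One-time O(N log N) step: indices that would sort vocab
order = np.argsort(vocab)

# Binary search of each query value in the sorted view of vocab,
# then map the sorted positions back to positions in the original array
pos_in_sorted = np.searchsorted(vocab, seq, sorter=order)
positions = order[pos_in_sorted]

# Sanity check: looking up the positions returns the queried values
assert np.array_equal(vocab[positions], seq)
```

Unlike the masking variants, this returns the positions in the order of the query sequence. The sort dominates the cost, so it is likely to pay off mainly when the same vocabulary is queried many times and the sorted order can be reused.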