Fast way to get all positions of the elements of a sublist


I am trying to get all the positions of the elements of a sublist that was drawn from a large list.

In Python, using numpy, suppose I have

from datetime import datetime as dt
import numpy as np
from numba import jit, int64

n, N = 20, 120000
int_vocabulary = np.arange(N)
np.random.shuffle(int_vocabulary)  # to make the problem non-trivial
int_sequence = np.random.choice(int_vocabulary, n, replace=False)

I want to get all the positions, within the large sequence int_vocabulary, of the integers in int_sequence. I am interested in a fast computation.

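So far I have tried a brute-force search with numba, a numpy masking approach, a brute-force list comprehension (as a baseline), and a mix of list comprehension and numpy masking.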

@jit(int64[:](int64[:], int64[:], int64, int64))
def check(int_sequence, int_vocabulary, n, N):
    all_indices = np.full(n, N)
    for xi in range(n):
        for i in range(N):
            if int_sequence[xi] == int_vocabulary[i]:
                all_indices[xi] = i
    return all_indices

t0 = dt.now()
for _ in range(10):
    all_indices0 = check(int_sequence, int_vocabulary, n, N)
t0 = (dt.now() - t0).total_seconds()
print("numba : ", t0)

t0 = dt.now()
for _ in range(10):
    mask = np.full(len(int_vocabulary), False)
    for x in int_sequence:
        mask += int_vocabulary == x
    all_indices1 = np.flatnonzero(mask)
t0 = (dt.now() - t0).total_seconds()
print("numpy :", t0)

t0 = dt.now()
for _ in range(10):
    all_indices2 = np.array([i for i, x in enumerate(int_vocabulary)
                             if x in int_sequence])
t0 = (dt.now() - t0).total_seconds()
print("list comprehension : ", t0)

t0 = dt.now()
for _ in range(10):
    mask = np.sum(np.array([int_vocabulary == x for x in int_sequence]), axis=0)
    all_indices3 = np.flatnonzero(mask)
t0 = (dt.now() - t0).total_seconds()
print("mixed numpy + list comprehension : ", t0)

# check() returns indices in int_sequence order; the other methods return them sorted
assert np.array_equal(np.sort(all_indices0), all_indices1)
assert np.array_equal(all_indices1, all_indices2)
assert np.array_equal(all_indices2, all_indices3)

Each method is run 10 times to get comparable statistics. The results (total seconds for the 10 runs) are

numba :  0.028039
numpy : 0.011616
list comprehension :  3.116753
mixed numpy + list comprehension :  0.032301
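
As a side note on measurement, the timing loops above could also be written with the standard-library timeit module, which repeats the measurement and lets you keep the least noisy round. This is only a sketch: the helper name numpy_mask and the choice of 5 rounds are assumptions made for illustration, and the data is regenerated with a seeded generator so that the block runs on its own.

```python
import timeit

import numpy as np


def numpy_mask(int_sequence, int_vocabulary):
    """The numpy masking approach from above, wrapped in a function (hypothetical name)."""
    mask = np.zeros(len(int_vocabulary), dtype=bool)
    for x in int_sequence:
        mask |= int_vocabulary == x  # mark every position holding x
    return np.flatnonzero(mask)


# Same sizes as in the question; the seed is an arbitrary choice for reproducibility.
rng = np.random.default_rng(0)
n, N = 20, 120000
int_vocabulary = rng.permutation(N)
int_sequence = rng.choice(int_vocabulary, n, replace=False)

# 5 rounds of 10 calls each; the minimum is the round least disturbed by other processes.
timings = timeit.repeat(
    lambda: numpy_mask(int_sequence, int_vocabulary),
    repeat=5,
    number=10,
)
best = min(timings)
print("numpy mask, best of 5 rounds of 10 calls:", best)
```

The same pattern applies to the other three variants by swapping the function passed in the lambda.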

I am still wondering whether there is a faster algorithm for this problem.
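
Two candidates that might be worth benchmarking are sketched below; neither has been timed here, so no speed claim is made. The first sorts the vocabulary once with np.argsort and then looks up each query with np.searchsorted, which assumes every element of int_sequence occurs exactly once in int_vocabulary (true for the setup above). The second uses np.isin to build the mask in a single call. The function names positions_via_sort and positions_via_isin are hypothetical.

```python
import numpy as np


def positions_via_sort(int_sequence, int_vocabulary):
    """Index in int_vocabulary of each element of int_sequence, in query order.

    Assumes each queried value is present exactly once in int_vocabulary.
    """
    order = np.argsort(int_vocabulary, kind="stable")    # indices that sort the vocabulary
    sorted_vocab = int_vocabulary[order]
    slots = np.searchsorted(sorted_vocab, int_sequence)  # binary search in the sorted copy
    return order[slots]                                  # map back to original positions


def positions_via_isin(int_sequence, int_vocabulary):
    """Sorted positions of all vocabulary entries that appear in int_sequence."""
    return np.flatnonzero(np.isin(int_vocabulary, int_sequence))


rng = np.random.default_rng(0)  # arbitrary seed for reproducibility
n, N = 20, 120000
int_vocabulary = rng.permutation(N)
int_sequence = rng.choice(int_vocabulary, n, replace=False)

by_sort = positions_via_sort(int_sequence, int_vocabulary)
by_isin = positions_via_isin(int_sequence, int_vocabulary)

# Both variants must point at the queried values.
assert np.array_equal(int_vocabulary[by_sort], int_sequence)
assert np.array_equal(np.sort(by_sort), by_isin)
```

The sort costs O(N log N) up front, so the first variant is most likely to pay off when the same vocabulary is queried many times and the sorted copy can be reused; for a single query of 20 elements the plain mask may still win.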

python numpy numba