Get the indices of matching elements in arrays, taking duplicates into account

Question · 3 votes · 3 answers

I want something like a SQL WHERE expression over two arrays in NumPy. Suppose I have two arrays like this:

import numpy as np
dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float_)])
a = np.rec.fromarrays([[3,    4,    4,   7,    9,    9],
                       [1,    5,    5,   4,    2,    2],
                       [2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
b = np.rec.fromarrays([[ 1,    4,   7,    9,    9],
                       [ 7,    5,   4,    2,    2],
                       [-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)

I want to get back indices into the original arrays such that every possible matching pair is covered. In addition, I'd like to exploit the fact that both arrays are sorted, so a worst-case O(mn) algorithm shouldn't be needed. In this case, since (4, 5, -4.5) matches but appears twice in the first array, it will appear twice in the resulting indices; and since (9, 2, 24.3) appears twice in both arrays, it will show up 2 × 2 = 4 times in total. Since (3, 1, 2.0) does not appear in the second array it is skipped, as is (1, 7, -3.5) from the second array. The function should work for any dtype.

In this case, the result would look like this:

a_idx, b_idx = match_arrays(a, b)
a_idx = np.array([1, 2, 3, 4, 4, 5, 5])
b_idx = np.array([1, 1, 2, 3, 4, 3, 4])

Another example, with the same output:

dt2 = np.dtype([('f1', np.uint8), ('f2', dt)])
a2 = np.rec.fromarrays([[3, 4, 4, 7, 9, 9], a], dtype=dt2)
b2 = np.rec.fromarrays([[1, 4, 7, 9, 9], b], dtype=dt2)

I have a pure Python implementation, but it is slow as molasses for my use case. I'm hoping for something more vectorized. This is what I have so far:

def match_arrays(a, b):
    len_a = len(a)
    len_b = len(b)

    a_idx = []
    b_idx = []

    i, j = 0, 0

    # index in b where the current run of matching values starts, so that
    # duplicates in a can re-scan the same run of b
    first_matched_j = 0

    while i < len_a and j < len_b:
        matched = False
        j = first_matched_j

        # emit a pair for every b element equal to a[i]
        while j < len_b and a[i] == b[j]:
            a_idx.append(i)
            b_idx.append(j)
            if not matched:
                matched = True
                first_matched_j = j

            j += 1
        else:
            # the loop above never breaks, so this always runs: advance to the next a element
            i += 1

        j = first_matched_j

        # a[i] is ahead of b: advance b (and the run start) until it catches up
        while i < len_a and j < len_b and a[i] > b[j]:
            j += 1
            first_matched_j = j

        # b is ahead of a: advance a until it catches up
        while i < len_a and j < len_b and a[i] < b[j]:
            i += 1

    return np.array(a_idx), np.array(b_idx)
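
As a quick sanity check (a sketch, assuming the recarrays a and b defined above are in scope), this reference implementation reproduces the expected indices:

a_idx, b_idx = match_arrays(a, b)
print(a_idx.tolist())  # expected: [1, 2, 3, 4, 4, 5, 5]
print(b_idx.tolist())  # expected: [1, 1, 2, 3, 4, 3, 4]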

Edit: As Divakar points out in his answer, I could simply use a_idx, b_idx = np.where(np.equal.outer(a, b)). However, that looks like a worst-case O(mn) solution, which I would like to avoid by exploiting the fact that the arrays are pre-sorted. In particular, O(m + n) when there are no duplicates would be great.

Edit 2: Paul Panzer's answer is not O(m + n) when only NumPy is used, but it is generally faster. He also provided an O(m + n) variant, so I accepted that one. I will post a performance comparison using timeit soon.

Edit 3: Here are the performance results, as promised:

╔════════════════╦═══════════════════╦═══════════════════╦═══════════════════╦══════════════════╦═══════════════════╗
║ User           ║ Version           ║ n = 10 ** 2       ║ n = 10 ** 4       ║ n = 10 ** 6      ║ n = 10 ** 8       ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Paul Panzer    ║ USE_HEAPQ = False ║ 115 µs ± 385 ns   ║ 793 µs ± 8.43 µs  ║ 105 ms ± 1.57 ms ║ 18.2 s ± 116 ms   ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ USE_HEAPQ = True  ║ 189 µs ± 3.6 µs   ║ 6.38 ms ± 28.8 µs ║ 650 ms ± 2.49 ms ║ 1min 11s ± 420 ms ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ SigmaPiEpsilon ║ Generator         ║ 936 µs ± 1.52 µs  ║ 9.17 s ± 57 ms    ║ N/A              ║ N/A               ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ for loop          ║ 144 µs ± 526 ns   ║ 15.6 ms ± 18.6 µs ║ 1.74 s ± 33.9 ms ║ N/A               ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Divakar        ║ np.where          ║ 39.1 µs ± 281 ns  ║ 302 ms ± 4.49 ms  ║ Out of memory    ║ N/A               ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ recarrays 1       ║ 69.9 µs ± 491 ns  ║ 1.6 ms ± 24.2 µs  ║ 230 ms ± 3.52 ms ║ 41.5 s ± 543 ms   ║
║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║                ║ recarrays 2       ║ 82.6 µs ± 1.01 µs ║ 1.4 ms ± 4.51 µs  ║ 212 ms ± 2.59 ms ║ 36.7 s ± 900 ms   ║
╚════════════════╩═══════════════════╩═══════════════════╩═══════════════════╩══════════════════╩═══════════════════╝

So it looks like Paul Panzer's answer with USE_HEAPQ = False wins. I expected USE_HEAPQ = True to win for large inputs, since it is O(m + n), but that turned out not to be the case. Another observation is that the USE_HEAPQ = False version also uses less memory, peaking at 5.79 GB versus 10.18 GB for USE_HEAPQ = True at n = 10 ** 8. Keep in mind that this is process memory, which includes the console input and other things. Divakar's recarrays answer 1 used 8.42 GB of memory and recarrays answer 2 used 10.61 GB.

python numpy matching
3 Answers
2 votes

Here is an O(n)-ish solution (obviously it cannot be strictly O(n) if the runs of duplicates are long). In practice, depending on the input lengths, it may pay to give up strict O(n) behavior and replace the stable heapq.merge with np.argsort. Currently, N = 10^6 takes about a second.

Code:

import numpy as np

USE_HEAPQ = True

def sqlwhere(a, b):
    # asw/bsw: start offsets of each run of equal values, with the array length appended as a sentinel
    asw = np.r_[0, 1 + np.flatnonzero(a[:-1]!=a[1:]), len(a)]
    bsw = np.r_[0, 1 + np.flatnonzero(b[:-1]!=b[1:]), len(b)]
    # run lengths and number of runs
    al, bl = np.diff(asw), np.diff(bsw)
    na, nb = len(al), len(bl)
    # one representative value per run: a's runs followed by b's runs
    abunq = np.r_[a[asw[:-1]], b[bsw[:-1]]]
    if USE_HEAPQ:
        # linear-time merge of the two (individually sorted) blocks of run values
        from heapq import merge
        m = np.fromiter(merge(range(na), range(na, na+nb), key=abunq.__getitem__), int, na+nb)
    else:
        # stable mergesort keeps an a-run ahead of a b-run with the same value
        m = np.argsort(abunq, kind='mergesort')
    mv = abunq[m]
    # adjacent equal values in the merged order mark an a-run matching a b-run
    midx = np.flatnonzero(mv[:-1]==mv[1:])
    ai, bi = m[midx], m[midx+1] - na
    # expand matched a-runs into element indices (ones/cumsum trick), then
    # repeat each a element once per element of the matching b-run
    aic = np.r_[0, np.cumsum(al[ai])]
    a_idx = np.ones((aic[-1],), dtype=int)
    a_idx[aic[:-1]] = asw[ai]
    a_idx[aic[1:-1]] -= asw[ai[:-1]] + al[ai[:-1]] - 1
    a_idx = np.repeat(np.cumsum(a_idx), np.repeat(bl[bi], al[ai]))
    # same construction for b: one matched b-run per emitted a element, then expand
    bi = np.repeat(bi, al[ai])
    bic = np.r_[0, np.cumsum(bl[bi])]
    b_idx = np.ones((bic[-1],), dtype=int)
    b_idx[bic[:-1]] = bsw[bi]
    b_idx[bic[1:-1]] -= bsw[bi[:-1]] + bl[bi[:-1]] - 1
    b_idx = np.cumsum(b_idx)
    return a_idx, b_idx

def f_D(a, b):
    return np.where(np.equal.outer(a,b))

def mock_data(n):
    return np.cumsum(np.random.randint(0, 3, (2, n)), axis=1)


a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)

# check correct
a, b = mock_data(1000)
ai0, bi0 = f_D(a, b)
ai1, bi1 = sqlwhere(a, b)
print(np.all(ai0 == ai1), np.all(bi0 == bi1))

# check fast
a, b = mock_data(1000000)
sqlwhere(a, b)
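
As a further sanity check (a sketch, using the plain integer versions of the question's arrays rather than the recarrays), sqlwhere should reproduce the indices expected in the question:

a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)
a_idx, b_idx = sqlwhere(a, b)
print(a_idx)  # expected: [1 2 3 4 4 5 5]
print(b_idx)  # expected: [1 1 2 3 4 3 4]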

2 votes

Approach #1: Broadcasting based

Leverage vectorized broadcasting by performing an outer equality comparison between the two arrays, then get the row and column indices, which are exactly the desired matching indices into the two arrays -

a_idx, b_idx = np.where(a[:,None]==b)
a_idx, b_idx = np.where(np.equal.outer(a,b))

We could also use np.nonzero in place of np.where.
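
A note on memory: the outer comparison materializes the full m-by-n boolean matrix before np.where scans it, which is why this approach runs out of memory at n = 10**6 in the question's benchmark table. A rough back-of-the-envelope sketch (assuming NumPy's 1 byte per boolean element):

m = n = 10**6
print(m * n / 1e9, "GB")  # ~1000 GB for the intermediate boolean matrix alone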

Approach #2: Specific-case solution

With sorted input arrays and no duplicates, we can use np.searchsorted, like so -

idx0 = np.searchsorted(a,b)
idx1 = np.searchsorted(b,a)
idx0[idx0==len(a)] = 0
idx1[idx1==len(b)] = 0

a_idx = idx0[a[idx0] == b]
b_idx = idx1[b[idx1] == a]

A slight modification that might be more efficient -

idx0 = np.searchsorted(a,b)
idx0[idx0==len(a)] = 0

a_idx = idx0[a[idx0] == b]
b_idx = np.searchsorted(b,a[a_idx])
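
A quick sample run of this special-case approach (a sketch, using deduplicated versions of the question's arrays, since this approach assumes sorted inputs without repeats):

a = np.array([3, 4, 7, 9])
b = np.array([1, 4, 7, 9])
idx0 = np.searchsorted(a,b)
idx0[idx0==len(a)] = 0
a_idx = idx0[a[idx0] == b]            # array([1, 2, 3])
b_idx = np.searchsorted(b,a[a_idx])   # array([1, 2, 3])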

Approach #3: Generic case

Here's a solution for the generic case (duplicates allowed) -

def findwhere(a, b):
    # count of each a value in b (np.bincount restricts this to non-negative integer inputs)
    c = np.bincount(b, minlength=a.max()+1)[a]
    # each matching a index repeated once per occurrence of its value in b
    a_idx1 = np.repeat(np.flatnonzero(c),c[c!=0])

    # start of the matching run in b for each emitted a index ...
    b_idx1 = np.searchsorted(b,a[a_idx1])
    # ... plus within-run offsets 0, 1, 2, ... wherever a_idx1 repeats
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1
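
As a quick check (a sketch on the plain integer versions of the question's arrays, since the np.bincount call needs non-negative integers), findwhere should reproduce the expected indices:

a = np.array([3, 4, 4, 7, 9, 9])
b = np.array([1, 4, 7, 9, 9])
a_idx, b_idx = findwhere(a, b)
print(a_idx)  # expected: [1 2 3 4 4 5 5]
print(b_idx)  # expected: [1 1 2 3 4 3 4]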

Timings

Setting up the inputs using mock_data from @Paul Panzer's soln:

In [295]: a, b = mock_data(1000000)

# @Paul Panzer's soln
In [296]: %timeit sqlwhere(a, b) # USE_HEAPQ = False
10 loops, best of 3: 118 ms per loop

# Approach #3 from this post
In [297]: %timeit findwhere(a,b)
10 loops, best of 3: 61.7 ms per loop

Utility to convert the recarrays (uint8 data) to 1D arrays

def convert_recarrays_to_1Darrs(a, b):
    a2D = a.view('u1').reshape(-1,2)
    b2D = b.view('u1').reshape(-1,2)
    s = max(a2D[:,0].max(), b2D[:,0].max())+1

    a1D = s*a2D[:,1] + a2D[:,0]
    b1D = s*b2D[:,1] + b2D[:,0]
    return a1D, b1D

Sample run -

In [90]: dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
    ...: a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
    ...:                        [1, 5, 5, 4, 2, 2]], dtype=dt)
    ...: b = np.rec.fromarrays([[1, 4, 7, 9, 9],
    ...:                        [7, 5, 4, 2, 2]], dtype=dt)

In [91]: convert_recarrays_to_1Darrs(a, b)
Out[91]: 
(array([13, 54, 54, 47, 29, 29], dtype=uint8),
 array([71, 54, 47, 29, 29], dtype=uint8))

Generic versions to cover rec-arrays

Version #1:

def findwhere_generic_v1(a, b):
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]

    a_starts = np.searchsorted(a,b_starts)
    a_starts[a_starts==len(a)] = 0

    valid_mask = (b_starts == a[a_starts])
    count_valid = count[valid_mask]

    idx2m0 = np.searchsorted(a,b_starts[valid_mask],'right')    
    idx1m0 = a_starts[valid_mask]

    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1

    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)

    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1

Version #2:

def findwhere_generic_v2(a, b):    
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]

    idxx = np.flatnonzero(np.r_[True,a[1:] != a[:-1],True])
    av = a[idxx[:-1]]
    idxxs = np.searchsorted(av,b_starts)
    idxxs[idxxs==len(av)] = 0
    valid_mask0 = av[idxxs] == b_starts

    starts = idxx[idxxs]
    stops = idxx[idxxs+1]

    idx1m0 = starts[valid_mask0]
    idx2m0 = stops[valid_mask0]  

    count_valid = count[valid_mask0]

    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1

    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)

    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1
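
A quick sanity check of the generic versions (a sketch, assuming the recarrays a and b from the question and that np.searchsorted handles their structured dtype, which is the setting benchmarked in the question's table):

a_idx, b_idx = findwhere_generic_v2(a, b)
print(a_idx)  # expected: [1 2 3 4 4 5 5]
print(b_idx)  # expected: [1 1 2 3 4 3 4]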

1 vote

Pure Python approaches

Generator comprehension

Another pure Python implementation, using a generator and comprehensions. It is probably more memory efficient than your code, but likely slower than the NumPy versions. For sorted arrays, the for-loop version below will be faster.

def pywheregen(a, b):

    l = ((ia,ib) for ia,j in enumerate(a) for ib,k in enumerate(b) if j == k)
    a_idx,b_idx = zip(*l)
    return a_idx,b_idx

Python for loop considering array sorting

Here is an alternative version using a simple Python for loop that takes into account that the arrays are sorted, so it only checks the pairs it needs to.

def pywhere(a, b):

    l = []
    a.sort()
    b.sort()
    # index in b where the most recently matched run of values starts
    match = 0
    for ia, j in enumerate(a):
        ib = match
        while ib < len(b) and j >= b[ib]:
            if j == b[ib]:
                l.append((ia, ib))
                # remember the run start so duplicates of j in a resume scanning
                # here instead of from the beginning of b
                if b[match] < b[ib]:
                    match = ib
            ib += 1

    a_ind, b_ind = zip(*l)

    return a_ind, b_ind
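
For example (a sketch against the question's arrays as plain Python lists; note that pywhere sorts its inputs in place):

a = [3, 4, 4, 7, 9, 9]
b = [1, 4, 7, 9, 9]
a_ind, b_ind = pywhere(a, b)
print(a_ind)  # expected: (1, 2, 3, 4, 4, 5, 5)
print(b_ind)  # expected: (1, 1, 2, 3, 4, 3, 4)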

Timings

I compared timings using @Paul Panzer's mock_data() function, against findwhere() and f_D() (the np.outer approach) from @Divakar. findwhere() still performs best, but pywhere() is not too bad considering that it is pure Python. pywheregen() does poorly, and surprisingly f_D() takes much longer as well; both of them fail for N = 10^6. I could not run sqlwhere due to an unrelated error from the heapq module.

In [2]: a, b = mock_data(10000)
In [10]: %timeit -n 10 findwhere(a,b)                                     
10 loops, best of 3: 1.62 ms per loop

In [11]: %timeit -n 10 pywhere(a,b)                                       
10 loops, best of 3: 20.6 ms per loop

In [12]: %timeit pywheregen(a,b)                                          
1 loop, best of 3: 12.7 s per loop

In [13]: %timeit -n 10 f_D(a,b)                                           
10 loops, best of 3: 476 ms per loop

In [14]: a, b = mock_data(1000000)
In [15]: %timeit -n 10 findwhere(a,b)                                     
10 loops, best of 3: 109 ms per loop

In [16]: %timeit -n 10 pywhere(a,b)                                       
10 loops, best of 3: 2.51 s per loop