检查两个numpy数组是否相同

问题描述 投票:6回答:6

假设我有一堆数组,包括xy,我想检查它们是否相等。一般来说,我可以使用np.all(x == y)(除非我现在忽略了一些愚蠢的角落情况)。

然而,这会评估整个(x == y)数组,这通常是不需要的。我的数组真的很大,我有很多,两个数组相等的概率很小,所以很可能,我真的只需要在(x == y)函数返回False之前评估一小部分all,所以这对我来说不是最佳解决方案。

我尝试使用内置的qazxsw poi函数,结合qazxsw poi:qazxsw poi

然而,在两个数组相等的情况下,这似乎要慢得多,总的来说,它不值得使用all。我认为因为内置的itertools.izip的一般目的。并且all(val1==val2 for val1,val2 in itertools.izip(x, y))不适用于发电机。

有没有办法以更快的方式做我想要的事情?

我知道这个问题类似于先前提出的问题(例如np.all),但它们没有特别涵盖提前终止的情况。

python numpy
6个回答
9
投票

在本地实现numpy之前,您可以编写自己的函数并使用all进行jit-compile:

np.all

最差的表现(数组等于)相当于Comparing two numpy arrays for equality, element-wise,并且在早期停止的情况下,编译的函数有可能大大超过numba


1
投票

import numpy as np import numba as nb @nb.jit(nopython=True) def arrays_equal(a, b): if a.shape != b.shape: return False for ai, bi in zip(a.flat, b.flat): if ai != bi: return False return True a = np.random.rand(10, 20, 30) b = np.random.rand(10, 20, 30) %timeit np.all(a==b) # 100000 loops, best of 3: 9.82 µs per loop %timeit arrays_equal(a, a) # 100000 loops, best of 3: 9.89 µs per loop %timeit arrays_equal(a, b) # 100000 loops, best of 3: 691 ns per loop 上显然正在讨论为阵列比较添加短路逻辑,因此可能会在未来的numpy版本中提供。


0
投票

您可以迭代数组的所有元素并检查它们是否相等。如果数组很可能不相等,则返回的速度比.all函数快得多。像这样的东西:

np.all

0
投票

可能是了解基础数据结构的人可以对此进行优化或解释它是否可靠/安全/良好实践,但它似乎有效。

np.all

如果我理解正确,numpy page on github会创建一个指向数据缓冲区的指针,而import numpy as np a = np.array([1, 2, 3]) b = np.array([1, 3, 4]) areEqual = True for x in range(0, a.size-1): if a[x] != b[x]: areEqual = False break else: print "a[x] is equal to b[x]\n" if areEqual: print "The tables are equal\n" else: print "The tables are not equal\n" 会创建一个可以从缓冲区中短路的本机python类型。

我认为。

编辑:进一步的测试显示它可能没有显示的那么大的时间改进。以前np.all(a==b) Out[]: True memoryview(a.data)==memoryview(b.data) Out[]: True %timeit np.all(a==b) The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 6.2 µs per loop %timeit memoryview(a.data)==memoryview(b.data) The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 1.85 µs per loop

ndarray.data

0
投票

嗯,我知道这是一个糟糕的答案,但似乎没有简单的方法。 Numpy Creators应该修复它。我建议:

memoryview

:)


0
投票

嗯,不是真正的答案,因为我没有检查它是否断路,但是:

a=b=np.eye(5)

从文档:

如果两个a=np.random.randint(0,10,(100,100)) b=a.copy() %timeit np.all(a==b) The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 3: 17.7 µs per loop %timeit memoryview(a.data)==memoryview(b.data) 10000 loops, best of 3: 30.1 µs per loop np.all(a==b) Out[]: True memoryview(a.data)==memoryview(b.data) Out[]: True 对象不相等,则引发AssertionError。

def compare(a, b): if len(a) > 0 and not np.array_equal(a[0], b[0]): return False if len(a) > 15 and not np.array_equal(a[:15], b[:15]): return False if len(a) > 200 and not np.array_equal(a[:200], b[:200]): return False return np.array_equal(a, b) assert_array_equal,如果不是性能敏感的代码路径。

或者遵循底层的源代码,也许它是有效的。

© www.soinside.com 2019 - 2024. All rights reserved.