假设我有一堆数组,包括x
和y
,我想检查它们是否相等。一般来说,我可以使用np.all(x == y)
(除非我现在忽略了一些愚蠢的角落情况)。
然而,这会评估整个(x == y)
数组,这通常是不需要的。我的数组真的很大,我有很多,两个数组相等的概率很小,所以很可能,我真的只需要在(x == y)
函数返回False之前评估一小部分all
,所以这对我来说不是最佳解决方案。
我尝试使用内置的qazxsw poi函数,结合qazxsw poi:qazxsw poi
然而,在两个数组相等的情况下,这似乎要慢得多,总的来说,它不值得使用all
。我认为因为内置的itertools.izip
的一般目的。并且all(val1==val2 for val1,val2 in itertools.izip(x, y))
不适用于发电机。
有没有办法以更快的方式做我想要的事情?
我知道这个问题类似于先前提出的问题(例如np.all
),但它们没有特别涵盖提前终止的情况。
在本地实现numpy之前,您可以编写自己的函数并使用all
进行jit-compile:
np.all
最差的表现(数组等于)相当于Comparing two numpy arrays for equality, element-wise,并且在早期停止的情况下,编译的函数有可能大大超过numba。
在import numpy as np
import numba as nb
@nb.jit(nopython=True)
def arrays_equal(a, b):
if a.shape != b.shape:
return False
for ai, bi in zip(a.flat, b.flat):
if ai != bi:
return False
return True
a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)
%timeit np.all(a==b) # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a) # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b) # 100000 loops, best of 3: 691 ns per loop
上显然正在讨论为阵列比较添加短路逻辑,因此可能会在未来的numpy版本中提供。
您可以迭代数组的所有元素并检查它们是否相等。如果数组很可能不相等,则返回的速度比.all函数快得多。像这样的东西:
np.all
可能是了解基础数据结构的人可以对此进行优化或解释它是否可靠/安全/良好实践,但它似乎有效。
np.all
如果我理解正确,numpy page on github会创建一个指向数据缓冲区的指针,而import numpy as np
a = np.array([1, 2, 3])
b = np.array([1, 3, 4])
areEqual = True
for x in range(0, a.size-1):
if a[x] != b[x]:
areEqual = False
break
else:
print "a[x] is equal to b[x]\n"
if areEqual:
print "The tables are equal\n"
else:
print "The tables are not equal\n"
会创建一个可以从缓冲区中短路的本机python类型。
我认为。
编辑:进一步的测试显示它可能没有显示的那么大的时间改进。以前np.all(a==b)
Out[]: True
memoryview(a.data)==memoryview(b.data)
Out[]: True
%timeit np.all(a==b)
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.2 µs per loop
%timeit memoryview(a.data)==memoryview(b.data)
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.85 µs per loop
ndarray.data
嗯,我知道这是一个糟糕的答案,但似乎没有简单的方法。 Numpy Creators应该修复它。我建议:
memoryview
:)
嗯,不是真正的答案,因为我没有检查它是否断路,但是:
a=b=np.eye(5)
。
从文档:
如果两个
a=np.random.randint(0,10,(100,100)) b=a.copy() %timeit np.all(a==b) The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 3: 17.7 µs per loop %timeit memoryview(a.data)==memoryview(b.data) 10000 loops, best of 3: 30.1 µs per loop np.all(a==b) Out[]: True memoryview(a.data)==memoryview(b.data) Out[]: True
对象不相等,则引发AssertionError。
def compare(a, b):
if len(a) > 0 and not np.array_equal(a[0], b[0]):
return False
if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
return False
if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
return False
return np.array_equal(a, b)
assert_array_equal,如果不是性能敏感的代码路径。
或者遵循底层的源代码,也许它是有效的。