我有一个numpy数组的列表,我想检查列表中是否有给定的数组。这有一些非常奇怪的行为,我想知道如何解决它。这是问题的简单版本:
import numpy as np
x = np.array([1,1])
a = [x,1]
x in a # Returns True
(x+1) in a # Throws ValueError
1 in a # Throws ValueError
我不知道这里发生了什么。有没有解决此问题的好方法?
我正在使用Python 3.7。
编辑:确切的错误是:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我的numpy版本是1.18.1。
[使用x in [1,x]
时,python会将x
与列表中的每个元素进行比较,在比较x == 1
期间,结果将是一个numpy数组:
>>> x == 1
array([ True, True])
并且将此数组解释为bool
值将由于固有的歧义而触发错误:
>>> bool(x == 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
您可以这样操作:
import numpy as np
x = np.array([1,1])
a = np.array([x.tolist(), 1])
x in a # True
(x+1) in a # False
1 in a # True
原因是in
或多或少被解释为
def in_sequence(elt, seq):
for i in seq:
if elt == i:
return True
return False
[1 == x
不给出False
,但引发异常,因为内部numpy将其转换为布尔数组。在大多数情况下它确实有意义,但是在这里却给出了愚蠢的行为。
听起来像个错误,但不容易修复。与1 == np.array(1, 1)
相同地处理np.array(1, 1) == np.array(1, 1)
是numpy的主要功能。将相等性比较委托给类是Python的一个主要功能。因此,我什至无法想象应该是正确的行为。
TL / DR:请不要混合使用Python列表和numpy数组,因为它们具有非常不同的语义,并且混合会导致不一致的特殊情况。
(([[EDIT:包括更一般的方法,也许更干净的方法)]
解决它的一种方法是实现NumPy安全版本的in
:import numpy as np
def in_np(x, items):
for item in items:
if isinstance(x, np.ndarray) and isinstance(item, np.ndarray) \
and x.shape == item.shape and np.all(x == item):
return True
elif isinstance(x, np.ndarray) or isinstance(item, np.ndarray):
pass
elif x == item:
return True
return False
x = np.array([1, 1])
a = [x, 1]
for y in (x, 0, 1, x + 1, np.array([1, 1, 1])):
print(in_np(y, a))
# True
# False
# True
# False
# False
或者甚至更好的是,编写具有任意比较的in
版本(可能默认为默认的in
行为),然后使用语义符合预期行为的np.array_equal()
对于==
。在代码中:
import operator def in_(x, items, eq=operator.eq): for item in items: if eq(x, item): return True return False
x = np.array([1, 1])
a = [x, 1]
for y in (x, 0, 1, x + 1, np.array([1, 1, 1])):
print(in_(y, a, np.array_equal))
# True
# False
# True
# False
# False
最后,请注意,items
可以是任何可迭代的,但是对于像O(1)
这样的哈希容器,操作的复杂度不会是set()
,尽管它仍会给出正确的结果:
print(in_(1, {1, 2, 3})) # True print(in_(0, {1, 2, 3})) # False in_(1, {1: 2, 3: 4}) # True in_(0, {1: 2, 3: 4}) # False