我有一个numpy数组的列表,我想检查列表中是否有给定的数组。这有一些非常奇怪的行为,我想知道如何解决它。这是问题的简单版本:
import numpy as np
x = np.array([1,1])
a = [x,1]
x in a # Returns True
(x+1) in a # Throws ValueError
1 in a # Throws ValueError
我不知道这里发生了什么。有没有解决此问题的好方法?
我正在使用Python 3.7。
编辑:确切的错误是:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我的numpy版本是1.18.1。
您可以这样操作:
import numpy as np
x = np.array([1,1])
a = np.array([x.tolist(), 1])
x in a # True
(x+1) in a # False
1 in a # True
解决它的一种方法是实现NumPy安全版本的in
:
import numpy as np
def in_np(item, container):
for x in container:
if isinstance(x, np.ndarray) and isinstance(item, np.ndarray) \
and x.shape == item.shape and np.all(x == item):
return True
elif isinstance(x, np.ndarray) or isinstance(item, np.ndarray):
pass
elif x == item:
return True
return False
x = np.array([1, 1])
a = [x, 1]
for y in (x, 0, 1, x + 1, np.array([1, 1, 1])):
print(in_np(y, a))
# True
# False
# True
# False
# False
[使用x in [1,x]
时,python会将x
与列表中的每个元素进行比较,在比较x == 1
期间,结果将是一个numpy数组:
>>> x == 1
array([ True, True])
并且将此数组解释为bool
值将由于固有的歧义而触发错误:
>>> bool(x == 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
原因是in
或多或少被解释为
def in_sequence(elt, seq):
for i in seq:
if elt == seq:
return True
return False
[1 == x
不给出False
,但引发异常,因为内部numpy将其转换为布尔数组。在大多数情况下它确实有意义,但是在这里却给出了愚蠢的行为。
听起来像个错误,但不容易修复。与1 == np.array(1, 1)
相同地处理np.array(1, 1) == np.array(1, 1)
是numpy的主要功能。将相等性比较委托给类是Python的一个主要功能。因此,我什至无法想象应该是正确的行为。
TL / DR:请不要混合使用Python列表和numpy数组,因为它们具有非常不同的语义,并且混合会导致不一致的特殊情况。