检查NumPy数组是否在Python列表中

问题描述 投票:2回答:4

我有一个numpy数组的列表,我想检查列表中是否有给定的数组。这有一些非常奇怪的行为,我想知道如何解决它。这是问题的简单版本:

import numpy as np
x = np.array([1,1])
a = [x,1]

x in a        # Returns True
(x+1) in a    # Throws ValueError
1 in a        # Throws ValueError

我不知道这里发生了什么。有没有解决此问题的好方法?

我正在使用Python 3.7。

编辑:确切的错误是:

ValueError: The truth value of an array with more than one element is ambiguous.  Use a.any() or a.all()

我的numpy版本是1.18.1。

python arrays list numpy contains
4个回答
1
投票

[使用x in [1,x]时,python会将x与列表中的每个元素进行比较,在比较x == 1期间,结果将是一个numpy数组:

>>> x == 1
array([ True,  True])

并且将此数组解释为bool值将由于固有的歧义而触发错误:

>>> bool(x == 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

0
投票

您可以这样操作:

import numpy as np
x = np.array([1,1])
a = np.array([x.tolist(), 1])

x in a # True
(x+1) in a # False
1 in a # True

0
投票

原因是in或多或少被解释为

def in_sequence(elt, seq):
    for i in seq:
        if elt == i:
            return True
    return False

[1 == x不给出False,但引发异常,因为内部numpy将其转换为布尔数组。在大多数情况下它确实有意义,但是在这里却给出了愚蠢的行为。

听起来像个错误,但不容易修复。与1 == np.array(1, 1)相同地处理np.array(1, 1) == np.array(1, 1)是numpy的主要功能。将相等性比较委托给类是Python的一个主要功能。因此,我什至无法想象应该是正确的行为。

TL / DR:请不要混合使用Python列表和numpy数组,因为它们具有非常不同的语义,并且混合会导致不一致的特殊情况。


0
投票

(([[EDIT:包括更一般的方法,也许更干净的方法)]

解决它的一种方法是实现NumPy安全版本的in

import numpy as np def in_np(x, items): for item in items: if isinstance(x, np.ndarray) and isinstance(item, np.ndarray) \ and x.shape == item.shape and np.all(x == item): return True elif isinstance(x, np.ndarray) or isinstance(item, np.ndarray): pass elif x == item: return True return False

x = np.array([1, 1])
a = [x, 1]

for y in (x, 0, 1, x + 1, np.array([1, 1, 1])):
    print(in_np(y, a))
# True
# False
# True
# False
# False

或者甚至更好的是,编写具有任意比较的in版本(可能默认为默认的in行为),然后使用语义符合预期行为的np.array_equal()对于==。在代码中:

import operator def in_(x, items, eq=operator.eq): for item in items: if eq(x, item): return True return False

x = np.array([1, 1])
a = [x, 1]

for y in (x, 0, 1, x + 1, np.array([1, 1, 1])):
    print(in_(y, a, np.array_equal))
# True
# False
# True
# False
# False

最后,请注意,items可以是任何可迭代的,但是对于像O(1)这样的哈希容器,操作的复杂度不会是set(),尽管它仍会给出正确的结果:

print(in_(1, {1, 2, 3})) # True print(in_(0, {1, 2, 3})) # False in_(1, {1: 2, 3: 4}) # True in_(0, {1: 2, 3: 4}) # False

© www.soinside.com 2019 - 2024. All rights reserved.