如何检查numpy数组是否在Python序列内?

问题描述 投票:0回答:1

我想检查给定数组是否在常规Python序列(列表,元组等)内。例如,考虑以下代码:

import numpy as np

xs = np.array([1, 2, 3])
ys = np.array([4, 5, 6])

myseq = (xs, 1, True, ys, 'hello')

我希望使用in进行简单的成员资格检查会起作用,例如:

>>> xs in myseq
True

但是如果我要查找的元素不在myseq的第一个位置,例如,显然会失败:

>>> ys in myseq
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

那么我该如何执行此检查?

[如果可能,我希望不必将myseq强制转换为numpy数组或任何其他类型的数据结构。

arrays python-3.x numpy sequence membership
1个回答
0
投票

这可能不是最美丽或禁忌的解决方案,但我认为它可以起作用:

import numpy as np


def array_in_tuple(array, tpl):
    i = 0
    while i < len(tpl):
        if isinstance(tpl[i], np.ndarray) and np.array_equal(array, tpl[i]):
            return True
        i += 1
    return False


xs = np.array([1, 2, 3])
ys = np.array([4, 5, 6])

myseq = (xs, 1, True, ys, 'hello')


print(array_in_tuple(xs, myseq), array_in_tuple(ys, myseq), array_in_tuple(np.array([7, 8, 9]), myseq))
© www.soinside.com 2019 - 2024. All rights reserved.