输入 numpy 数组子类的 `__get_item__` 方法

问题描述 投票:0回答:1

让我们考虑

numpy
ndarray
类的子类:

import numpy as np

class ArraySubClass(np.ndarray):
    def __new__(cls, input_array: np.ndarray):
        obj = np.asarray(input_array).view(cls)
        return obj

然后,获取

ArraySubClass
的切片返回相同类型的对象,如 文档中所述:

>>> type(ArraySubClass(np.zeros((3, 3)))[:, 0])
<class '__main__.ArraySubClass'>

到目前为止一切顺利,但是当我使用检查的静态类型时,我开始出现意想不到的行为

pyright
,如下面的示例所示:

def f(x: ArraySubClass):
    print(x)

f(ArraySubClass(np.zeros((3, 3)))[:, 0])

最后一行从

pyright
引发错误:

 Argument of type "ndarray[Any, Unknown]" cannot be assigned to parameter "x" of type "ArraySubClass" in function "f"
  "ndarray[Any, Unknown]" is incompatible with "ArraySubClass"

是什么原因导致这种行为?这是一个

pyright
错误吗?类型提示给出的
__get_item__
中的
np.ndarray
方法的签名是否不正确?或者也许我应该用正确的签名覆盖
ArraySubClass
中的这个方法?

python numpy python-typing pyright
1个回答
0
投票

您的测试:

>>> type(ArraySubClass(np.zeros((3, 3)))[:, 0])
<class '__main__.ArraySubClass'>

表明

ArraySubClass(np.zeros((3, 3)))[:, 0]
的具体结果确实是
ArraySubClass
,因此您认为这不会造成麻烦的期望是有道理的:

def f(x: ArraySubClass):
    print(x)

f(ArraySubClass(np.zeros((3, 3)))[:, 0])

但是,

f
特别需要一个
ArraySubClass
,并且 linter 不会运行代码来查看
x
最终是什么。它必须依靠类型提示来弄清楚
x
could 是什么。

由于

x
ArraySubClass(np.zeros((3, 3)))[:, 0]
的结果,这意味着您在
__get_item__
上调用
ArraySubClass(np.zeros((3, 3)))
,其定义为(对于
ndarray
):

def __getitem__(self, key: (
    NDArray[integer[Any]]
    | NDArray[bool_]
    | tuple[NDArray[integer[Any]] | NDArray[bool_], ...]
)) -> ndarray[Any, _DType_co]: ...

所以,

pyright
所知道的是,
ndarray
将会出现。这是事实,但是这个
ndarray
也是一个
ArraySubClass
的信息丢失了,
pyright
无法分辨。

就像用户 @barmar 在评论中指出的那样,如果

__get_item__
使用泛型实现,则可以避免这种情况,从而让 Python(和 linter)清楚地知道传入的是
NDArray
不仅是真的出来的是
ndarray[Any, _DType_co]
,但实际上出来的是 相同类型的
NDArray

但事实并非如此,这是有充分理由的。 (我不是在发表意见,我只是说我还没有更深入地研究代码或文档来推测是否确实有充分的理由)

© www.soinside.com 2019 - 2024. All rights reserved.