具有特定形状和数据类型的 Numpy 类型

问题描述 投票:0回答:3

目前我正在尝试更多地使用 numpy 类型来使我的代码更清晰,但是我已经达到了目前无法覆盖的限制。

是否可以指定特定的形状以及相应的数据类型? 示例:

Shape=(4,)
datatype= np.int32

到目前为止,我的尝试如下所示(但都只是抛出错误):

第一次尝试:

import numpy as np

def foo(x: np.ndarray[(4,), np.dtype[np.int32]]):
...
result -> 'numpy._DTypeMeta' object is not subscriptable

第二次尝试:

import numpy as np
import numpy.typing as npt

def foo(x: npt.NDArray[(4,), np.int32]):
...
result -> Too many arguments for numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]

此外,不幸的是,我在文档中找不到有关它的任何信息,或者只有当我按照文档记录的方式实现它时才会出现错误。

numpy python-typing typing
3个回答
29
投票

目前,

numpy.typing.NDArray
仅接受dtype,如下所示:
numpy.typing.NDArray[numpy.int32]
。不过你有一些选择。

使用
typing.Annotated

typing.Annotated
允许您为类型创建别名并与其捆绑一些额外信息。

在某些

my_types.py
中,您可以写下您想要提示的形状的所有变体:

from typing import Annotated, Literal, TypeVar
import numpy as np
import numpy.typing as npt


DType = TypeVar("DType", bound=np.generic)

Array4 = Annotated[npt.NDArray[DType], Literal[4]]
Array3x3 = Annotated[npt.NDArray[DType], Literal[3, 3]]
ArrayNxNx3 = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]

然后在

foo.py
中,您可以提供 numpy dtype 并将它们用作类型提示:

import numpy as np
from my_types import Array4


def foo(arr: Array4[np.int32]):
    assert arr.shape == (4,)

MyPy 会将

arr
识别为
np.ndarray
并对其进行检查。形状检查只能在运行时完成,就像本例中的
assert

如果你不喜欢这个断言,你可以发挥你的创造力来定义一个函数来为你做检查。

def assert_match(arr, array_type):
    hinted_shape = array_type.__metadata__[0].__args__
    hinted_dtype_type = array_type.__args__[0].__args__[1]
    hinted_dtype = hinted_dtype_type.__args__[0]
    assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
    assert arr.shape == hinted_shape, "Shape does not match"


assert_match(some_array, Array4[np.int32])

使用
nptyping

另一种选择是使用第 3 方库

nptyping
(是的,我是作者)。

你会放弃

my_types.py
,因为它不再有用了。

你的

foo.py
会变成这样:

from nptyping import NDArray, Shape, Int32


def foo(arr: NDArray[Shape["4"], Int32]):
    assert isinstance(arr, NDArray[Shape["4"], Int32])

使用
beartype
+
typing.Annotated

还有另一个名为

beartype
的第三方库可供您使用。它可以采用
typing.Annotated
方法的变体,并为您进行运行时检查。

您将恢复您的

my_types.py
,内容类似于:

from beartype import beartype
from beartype.vale import Is
from typing import Annotated
import numpy as np


Int32Array4 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (4,) and np.issubdtype(array.dtype, np.int32)]]
Int32Array3x3 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (3,3) and np.issubdtype(array.dtype, np.int32)]]

你的

foo.py
将变成:

import numpy as np
from beartype import beartype
from my_types import Int32Array4 


@beartype
def foo(arr: Int32Array4):
    ...  # Runtime type checked by beartype.

使用
beartype
+
nptyping

您还可以堆叠两个库。

您的

my_types.py
可以再次删除,您的
foo.py
将变成类似以下内容:

from nptyping import NDArray, Shape, Int32
from beartype import beartype


@beartype
def foo(arr: NDArray[Shape["4"], Int32]):
    ...  # Runtime type checked by beartype.

1
投票

我过去常常这样进行:

def foo(x):
    x = np.array(x, dtype=np.int32)
    if x.shape!=Shape:
        raise ValueError("Shape mismatch")
    #...

如果您对形状有特定问题,您应该根据您期望的输入形状重新调整形状。如果您需要帮助来正确地重塑它,请提供您的输入示例

x


0
投票

尝试 JaxTyping https://github.com/google/jaxtyping/tree/main 不仅适用于 numpy,还适用于 PyTorch 和 Tensorflow。

© www.soinside.com 2019 - 2024. All rights reserved.