目前我正在尝试更多地使用 numpy 类型来使我的代码更清晰,但是我已经达到了目前无法覆盖的限制。
是否可以指定特定的形状以及相应的数据类型? 示例:
Shape=(4,)
datatype= np.int32
到目前为止,我的尝试如下所示(但都只是抛出错误):
第一次尝试:
import numpy as np
def foo(x: np.ndarray[(4,), np.dtype[np.int32]]):
...
result -> 'numpy._DTypeMeta' object is not subscriptable
第二次尝试:
import numpy as np
import numpy.typing as npt
def foo(x: npt.NDArray[(4,), np.int32]):
...
result -> Too many arguments for numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]
此外,不幸的是,我在文档中找不到有关它的任何信息,或者只有当我按照文档记录的方式实现它时才会出现错误。
目前,
numpy.typing.NDArray
仅接受dtype,如下所示:numpy.typing.NDArray[numpy.int32]
。不过你有一些选择。
typing.Annotated
typing.Annotated
允许您为类型创建别名并与其捆绑一些额外信息。
在某些
my_types.py
中,您可以写下您想要提示的形状的所有变体:
from typing import Annotated, Literal, TypeVar
import numpy as np
import numpy.typing as npt
DType = TypeVar("DType", bound=np.generic)
Array4 = Annotated[npt.NDArray[DType], Literal[4]]
Array3x3 = Annotated[npt.NDArray[DType], Literal[3, 3]]
ArrayNxNx3 = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]
然后在
foo.py
中,您可以提供 numpy dtype 并将它们用作类型提示:
import numpy as np
from my_types import Array4
def foo(arr: Array4[np.int32]):
assert arr.shape == (4,)
MyPy 会将
arr
识别为 np.ndarray
并对其进行检查。形状检查只能在运行时完成,就像本例中的 assert
。
如果你不喜欢这个断言,你可以发挥你的创造力来定义一个函数来为你做检查。
def assert_match(arr, array_type):
hinted_shape = array_type.__metadata__[0].__args__
hinted_dtype_type = array_type.__args__[0].__args__[1]
hinted_dtype = hinted_dtype_type.__args__[0]
assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
assert arr.shape == hinted_shape, "Shape does not match"
assert_match(some_array, Array4[np.int32])
nptyping
nptyping
(是的,我是作者)。
你会放弃
my_types.py
,因为它不再有用了。
你的
foo.py
会变成这样:
from nptyping import NDArray, Shape, Int32
def foo(arr: NDArray[Shape["4"], Int32]):
assert isinstance(arr, NDArray[Shape["4"], Int32])
beartype
+ typing.Annotated
beartype
的第三方库可供您使用。它可以采用 typing.Annotated
方法的变体,并为您进行运行时检查。
您将恢复您的
my_types.py
,内容类似于:
from beartype import beartype
from beartype.vale import Is
from typing import Annotated
import numpy as np
Int32Array4 = Annotated[np.ndarray, Is[lambda array:
array.shape == (4,) and np.issubdtype(array.dtype, np.int32)]]
Int32Array3x3 = Annotated[np.ndarray, Is[lambda array:
array.shape == (3,3) and np.issubdtype(array.dtype, np.int32)]]
你的
foo.py
将变成:
import numpy as np
from beartype import beartype
from my_types import Int32Array4
@beartype
def foo(arr: Int32Array4):
... # Runtime type checked by beartype.
beartype
+ nptyping
您还可以堆叠两个库。
您的
my_types.py
可以再次删除,您的 foo.py
将变成类似以下内容:
from nptyping import NDArray, Shape, Int32
from beartype import beartype
@beartype
def foo(arr: NDArray[Shape["4"], Int32]):
... # Runtime type checked by beartype.
我过去常常这样进行:
def foo(x):
x = np.array(x, dtype=np.int32)
if x.shape!=Shape:
raise ValueError("Shape mismatch")
#...
如果您对形状有特定问题,您应该根据您期望的输入形状重新调整形状。如果您需要帮助来正确地重塑它,请提供您的输入示例
x
。
尝试 JaxTyping https://github.com/google/jaxtyping/tree/main 不仅适用于 numpy,还适用于 PyTorch 和 Tensorflow。