考虑以下文件:
import jax.numpy as jnp
def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
return a + b
运行
mypy mypytest.py
返回以下错误:
mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")
出于某种原因,它相信添加两个
jax.numpy.ndarray
会返回一个由bools组成的NumPy数组。难道我做错了什么?或者这是 MyPy 或 Jax 类型注释中的错误?
至少静态地来说,
jnp.ndarray
是np.ndarray
的子类,修改非常少
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
因此,它继承了
np.ndarray
的方法类型签名。
我猜运行时行为是通过
jnp.array
函数实现的。除非我错过了一些存根文件或类型欺骗,否则 jnp.array
的结果与 jnp.ndarray
匹配只是因为 jnp.array
是无类型的。你可以用来测试一下
def foo(_: str) -> None:
pass
foo(jnp.array(0))
通过 mypy。
所以回答你的问题,我不认为你做错了什么。从某种意义上说,这是一个错误,它可能不是他们的意思,但实际上并不是不正确,因为当您添加
np.ndarray
时,您确实会得到 jnp.ndarray
,因为 jnp.ndarray
是 np.ndarray
。
至于为什么
bool
s,这可能是因为你的jnp.array
s缺少通用参数,并且__add__
上np.ndarray
的第一个有效重载是
@overload
def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
所以它只是默认为
bool
。
一般来说,JAX 与 mypy 的兼容性非常差,因为 JAX 的转换模型很难满足 mypy 的约束,该模型经常调用具有特定于转换的跟踪器值的函数,这些值充当数组的替代品(请参阅 如何思考JAX:JIT Mechanics 对此机制的简要讨论)。
使用跟踪器类型作为数组的代表意味着 mypy 在转换严格类型的 JAX 函数时会引发错误,因此在整个 JAX 代码库中,我们倾向于将
Array
别名为 Any
,并将其用作返回数组的 JAX 函数的返回类型注释。
对此进行改进会很好,因为
Any
返回类型对于有效的类型检查不是很有用,但这只是使 mypy 与 JAX 良好配合的众多挑战中的第一个。如果您想阅读过去几年有关此问题的一些值得讨论的内容,我将从这里开始:https://github.com/google/jax/issues/943
同时,我的建议是使用
Any
作为 JAX 数组的类型注释。
截至 2023 年末,
jax
似乎已大大改进了其打字注释。 mypy
可以使用 新语法:
from jax import Array
from jax.typing import ArrayLike
import jax.numpy as jnp
def test(a: ArrayLike, b: ArrayLike) -> Array:
return a + b