我想在 mlx 中实现 masked_fill,但它与 float('-inf') 配合得不太好
https://pytorch.org/docs/stable/ generated/torch.Tensor.masked_fill.html
我正在尝试使用 mlx.core.where
masked_tensor = mlx.core.where(mask, mlx.core.array(float('-inf')), mlx.core.array(0))
但是为了面膜
array([[False, False, True, True],
[False, False, True, True],
[False, False, True, True],
[False, False, True, True]], dtype=bool)
这回来了
array([[nan, nan, -inf, -inf],
[nan, nan, -inf, -inf],
[nan, nan, -inf, -inf],
[nan, nan, -inf, -inf]], dtype=float32)
这不是我想要的。理想情况下它会回来
array([[0, 0, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, -inf, -inf]], dtype=float32)
帮助
这应该适用于最新的 mlx:
pip install -U mlx
>>> import mlx.core as mx
>>> mask = mx.array([True, False])
>>> mx.where(mask, mx.array(float("-inf")), mx.array(0.0))
array([-inf, 0], dtype=float32)