如何在mlx中进行蒙版填充?

问题描述 投票:0回答:1

我想在 mlx 中实现 masked_fill,但它与 float('-inf') 配合得不太好

https://pytorch.org/docs/stable/ generated/torch.Tensor.masked_fill.html

我正在尝试使用 mlx.core.where

masked_tensor = mlx.core.where(mask, mlx.core.array(float('-inf')), mlx.core.array(0))

但是为了面膜

array([[False, False, True, True],                                                                                                                                                      
       [False, False, True, True],                                                                                                                                                     
       [False, False, True, True],                                                                                                                                                    
       [False, False, True, True]], dtype=bool) 

这回来了

array([[nan, nan, -inf, -inf],                                                                                                                                              
       [nan, nan, -inf, -inf],                                                                                                                                                
       [nan, nan, -inf, -inf],                                                                                            
       [nan, nan, -inf, -inf]], dtype=float32)   

这不是我想要的。理想情况下它会回来

array([[0, 0, -inf, -inf],                                                                                                                                              
       [0, 0, -inf, -inf],                                                                                                                                                
       [0, 0, -inf, -inf],                                                                                            
       [0, 0, -inf, -inf]], dtype=float32)  

帮助

python machine-learning pytorch tensor
1个回答
0
投票

这应该适用于最新的 mlx:

pip install -U mlx

>>> import mlx.core as mx
>>> mask = mx.array([True, False])
>>> mx.where(mask, mx.array(float("-inf")), mx.array(0.0))
array([-inf, 0], dtype=float32)
© www.soinside.com 2019 - 2024. All rights reserved.