PyTorch中的消除和填充

问题描述 投票:1回答:1

假设我有一个这样的张量

w = [[0.1, 0.7, 0.7, 0.8, 0.3],
    [0.3, 0.2, 0.9, 0.1, 0.5],
    [0.1, 0.4, 0.8, 0.3, 0.4]]

现在,我想根据某些条件消除某些值(例如大于或小于0.5)

w = [[0.1, 0.3],
     [0.3, 0.2, 0.1],
     [0.1, 0.4, 0.3, 0.4]]

然后将其填充到相等的长度:

w = [[0.1, 0.3, 0, 0],
     [0.3, 0.2, 0.1, 0],
     [0.1, 0.4, 0.3, 0.4]]

这就是我在pytorch中实现它的方式:

w = torch.rand(3, 5)
condition = w <= 0.5
w = [w[i][condition[i]] for i in range(3)]
w = torch.nn.utils.rnn.pad_sequence(w)

但是显然,这将非常缓慢,主要是由于列表理解。有没有更好的方法呢?

python pytorch tensor
1个回答
0
投票

这是使用boolean maskingtensor splitting,然后最终使用torch.nn.utils.rnn.pad_sequence(...)填充分割张量的一种直接方法。

torch.nn.utils.rnn.pad_sequence(...)
© www.soinside.com 2019 - 2024. All rights reserved.