假设我有一个这样的张量
w = [[0.1, 0.7, 0.7, 0.8, 0.3],
[0.3, 0.2, 0.9, 0.1, 0.5],
[0.1, 0.4, 0.8, 0.3, 0.4]]
现在,我想根据某些条件消除某些值(例如大于或小于0.5)
w = [[0.1, 0.3],
[0.3, 0.2, 0.1],
[0.1, 0.4, 0.3, 0.4]]
然后将其填充到相等的长度:
w = [[0.1, 0.3, 0, 0],
[0.3, 0.2, 0.1, 0],
[0.1, 0.4, 0.3, 0.4]]
这就是我在pytorch中实现它的方式:
w = torch.rand(3, 5)
condition = w <= 0.5
w = [w[i][condition[i]] for i in range(3)]
w = torch.nn.utils.rnn.pad_sequence(w)
但是显然,这将非常缓慢,主要是由于列表理解。有没有更好的方法呢?
这是使用boolean masking,tensor splitting,然后最终使用torch.nn.utils.rnn.pad_sequence(...)
填充分割张量的一种直接方法。
torch.nn.utils.rnn.pad_sequence(...)