How to vectorize these 2 loops in PyTorch (hard)

Question · votes: 0 · answers: 1

How can I vectorize the following?

import torch

vocab_size = 20
batch_size = 2
input_len = 5
output_len = 10
input_ids = torch.randint(0, vocab_size, (batch_size, input_len))
output_ids = torch.randint(0, vocab_size, (batch_size, output_len))

print(input_ids)
print(output_ids)

tensor([[ 0,  8,  7, 12,  8],
        [14, 15,  9,  7, 10]])
tensor([[ 2,  8,  3, 15,  2, 19,  7,  1, 19,  8],
        [10,  8,  0,  7, 16,  0,  6,  2, 16, 13]])

Basically, a new value in output_ids should be vocab_size + the index of that value in input_ids. A value may appear more than once in both input_ids and output_ids, so when a value occurs for the k-th time in output_ids, it should be replaced by vocab_size + the index of its k-th occurrence in input_ids (the loop below only ever takes the first occurrence). I edited the expected output by hand to follow that rule (the 21 and 24 in the first row of the output).

This is what I want:

for i in range(batch_size):
    for k, value in enumerate(output_ids[i]):
        if value in input_ids[i] and value not in [0, 1, 2]:  # i.e., ignore the values 0, 1, 2
            output_ids[i][k] = vocab_size + torch.where(input_ids[i] == value)[0][0]
output_ids

tensor([[ 2, 21,  3, 15,  2, 19, 22,  1, 19, 24],
        [24,  8,  0, 23, 16,  0,  6,  2, 16, 13]])
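The loop above only ever uses the first matching index in input_ids. A minimal sketch of a per-row occurrence counter that honors the k-th occurrence rule described earlier (still loop-based, not vectorized; the `seen` dict and the clamping of `j` to the last match are my own assumptions about how extra repeats should behave):

```python
import torch

vocab_size = 20
input_ids = torch.tensor([[0, 8, 7, 12, 8],
                          [14, 15, 9, 7, 10]])
output_ids = torch.tensor([[2, 8, 3, 15, 2, 19, 7, 1, 19, 8],
                           [10, 8, 0, 7, 16, 0, 6, 2, 16, 13]])

# Count how many times each value has been seen per row, so the k-th
# occurrence in output_ids maps to the k-th matching index in input_ids.
for i in range(output_ids.size(0)):
    seen = {}
    for k, value in enumerate(output_ids[i].tolist()):
        if value in (0, 1, 2):          # ignored values
            continue
        positions = torch.where(input_ids[i] == value)[0]
        if len(positions) == 0:         # value absent from input_ids
            continue
        j = min(seen.get(value, 0), len(positions) - 1)  # clamp to last match
        output_ids[i][k] = vocab_size + positions[j]
        seen[value] = seen.get(value, 0) + 1

print(output_ids)
# tensor([[ 2, 21,  3, 15,  2, 19, 22,  1, 19, 24],
#         [24,  8,  0, 23, 16,  0,  6,  2, 16, 13]])
```

With the question's tensors this reproduces the hand-edited expected output, including the 21/24 pair for the two occurrences of 8 in the first row.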
Tags: python, pytorch, vectorization
1 Answer

0 votes

I came up with this vectorized version. It is not memory-optimized because of the .clone() calls:

output_ids_para = output_ids.clone()

## Mask of the values to keep (the second condition of the if statement);
## "~mask" below selects the ignored values 0, 1 and 2.
mask = (output_ids != 0) & (output_ids != 1) & (output_ids != 2)

## Temporarily replace the ignored values with a sentinel that cannot
## appear in input_ids, so they never match below.
output_ids_para[~mask] = vocab_size + 9999

## Parallelize the comparison (first condition of the if statement + torch.where()):
## broadcast input_ids against output_ids over a (batch, input_len, output_len) grid.
input_ids_expand = input_ids.unsqueeze(-1).expand(batch_size, input_len, output_len)
output_ids_expand = output_ids_para.unsqueeze(1).expand(batch_size, input_len, output_len)
indices_i, values, indices_k = torch.where(input_ids_expand == output_ids_expand)

## The newly assigned output_ids: "values" holds the matching positions in input_ids.
output_ids_para[indices_i, indices_k] = vocab_size + values

## Restore the ignored values (those that fail the second condition of the if statement).
output_ids_para[~mask] = output_ids[~mask]

output_ids_para
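For reference, here is a self-contained run of this answer on the question's tensors. One caveat (my own observation, not from the answer): when a value appears more than once in input_ids, the scattered writes at the same output position overlap, and which input index wins is not guaranteed, so those entries may differ from the loop's first-index choice.

```python
import torch

vocab_size = 20
input_ids = torch.tensor([[0, 8, 7, 12, 8],
                          [14, 15, 9, 7, 10]])
output_ids = torch.tensor([[2, 8, 3, 15, 2, 19, 7, 1, 19, 8],
                           [10, 8, 0, 7, 16, 0, 6, 2, 16, 13]])
batch_size, input_len = input_ids.shape
output_len = output_ids.size(1)

output_ids_para = output_ids.clone()

# Values 0, 1, 2 are ignored; park them on a sentinel outside the vocabulary.
mask = (output_ids != 0) & (output_ids != 1) & (output_ids != 2)
output_ids_para[~mask] = vocab_size + 9999

# Compare every input position against every output position at once.
input_ids_expand = input_ids.unsqueeze(-1).expand(batch_size, input_len, output_len)
output_ids_expand = output_ids_para.unsqueeze(1).expand(batch_size, input_len, output_len)
indices_i, values, indices_k = torch.where(input_ids_expand == output_ids_expand)

# "values" is the position of each match inside input_ids.
output_ids_para[indices_i, indices_k] = vocab_size + values

# Put the ignored values back.
output_ids_para[~mask] = output_ids[~mask]
print(output_ids_para)
```

The second row matches the loop exactly (no duplicated values in its input row); in the first row, the positions holding 8 can resolve to either 21 or 24 because 8 appears twice in input_ids.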