How can I vectorize this?
import torch
vocab_size = 20
batch_size = 2
input_len = 5
output_len = 10
input_ids = torch.randint(0, vocab_size, (batch_size, input_len))
output_ids = torch.randint(0, vocab_size, (batch_size, output_len))
print(input_ids)
print(output_ids)
tensor([[ 0,  8,  7, 12,  8],
        [14, 15,  9,  7, 10]])
tensor([[ 2,  8,  3, 15,  2, 19,  7,  1, 19,  8],
        [10,  8,  0,  7, 16,  0,  6,  2, 16, 13]])
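Basically, the new value in output_ids should be vocab_size + the index of the k-th occurrence of that value in input_ids, because a value can appear several times in both input_ids and output_ids. So if a value appears in output_ids for the second time, it is replaced by vocab_size + the index of its second occurrence in input_ids. The loop below only handles the first occurrence, so I edited the expected output by hand to show what I want (the 21 and 24 in the first row).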
Here is what I want, written as a loop:
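expected = output_ids.clone()  # copy, so output_ids stays unchanged for the vectorized version below
for i in range(batch_size):
    for k, value in enumerate(output_ids[i]):
        # skip the special values 0, 1, 2
        if value in input_ids[i] and value not in [0, 1, 2]:
            # index of the FIRST occurrence of value in input_ids[i]
            expected[i, k] = vocab_size + torch.where(input_ids[i] == value)[0][0]
print(expected)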
tensor([[ 2, 21,  3, 15,  2, 19, 22,  1, 19, 24],
        [24,  8,  0, 23, 16,  0,  6,  2, 16, 13]])
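The loop above only ever looks up the first occurrence (the `[0][0]`), which is why the 24 had to be edited in by hand. Below is a minimal sketch of a loop that follows the full k-th occurrence rule, useful as a reference to check any vectorized version against. It hardcodes the example tensors printed above, and it assumes (the question does not say) that when a value occurs more often in output_ids[i] than in input_ids[i], the surplus occurrences are left unchanged. The function name `replace_kth_occurrence_loop` is hypothetical.

```python
import torch

# Example tensors copied from the printed output above
vocab_size = 20
input_ids = torch.tensor([[0, 8, 7, 12, 8],
                          [14, 15, 9, 7, 10]])
output_ids = torch.tensor([[2, 8, 3, 15, 2, 19, 7, 1, 19, 8],
                           [10, 8, 0, 7, 16, 0, 6, 2, 16, 13]])


def replace_kth_occurrence_loop(input_ids, output_ids, vocab_size, special=(0, 1, 2)):
    """Hypothetical reference (non-vectorized) version of the k-th occurrence rule."""
    result = output_ids.clone()
    for i in range(input_ids.shape[0]):
        # positions of each value in input_ids[i], in order of appearance
        positions = {}
        for j, v in enumerate(input_ids[i].tolist()):
            positions.setdefault(v, []).append(j)
        seen = {}  # how often each value has appeared so far in output_ids[i]
        for k, v in enumerate(output_ids[i].tolist()):
            if v in special:
                continue
            n = seen.get(v, 0)  # this is occurrence number n (0-based) of v
            seen[v] = n + 1
            # Assumption: surplus occurrences (n >= count in input) stay unchanged
            if n < len(positions.get(v, [])):
                result[i, k] = vocab_size + positions[v][n]
    return result


print(replace_kth_occurrence_loop(input_ids, output_ids, vocab_size))
```

On the example tensors this prints the hand-edited expected output shown above, including the 24 at the end of the first row.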
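I came up with the vectorized version below. It is not optimized for memory: it uses .clone(), and the == comparison on the expanded tensors materializes a boolean tensor of shape (batch_size, input_len, output_len).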
output_ids_para = output_ids.clone()
# mask is True where the value is NOT one of the special values 0, 1, 2
# (the second condition of the if statement); ~mask selects the special values.
mask = (output_ids != 0) & (output_ids != 1) & (output_ids != 2)
# Temporarily replace the special values with a value that cannot appear in input_ids
output_ids_para[~mask] = vocab_size + 9999
## Parallelize the comparison (first condition of the if statement + torch.where())
input_ids_expand = input_ids.unsqueeze(-1).expand(batch_size, input_len, output_len)
output_ids_expand = output_ids_para.unsqueeze(1).expand(batch_size, input_len, output_len)
indices_i, values, indices_k = torch.where(input_ids_expand == output_ids_expand)
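# The newly assigned output_ids.
# Caveat: if a value occurs more than once in input_ids[i], the same (i, k) pair
# appears several times in (indices_i, indices_k), and PyTorch does not define which
# of the duplicate writes wins. So this is only guaranteed to match the loop when
# no value repeats within a row of input_ids.
output_ids_para[indices_i, indices_k] = vocab_size + values
# Restore the original values at the positions that fail the second condition (0, 1, 2)
output_ids_para[~mask] = output_ids[~mask]
print(output_ids_para)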
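One possible fully vectorized version of the k-th occurrence rule is sketched below, under the same assumption that surplus occurrences in output_ids are left unchanged. The idea: give every element its occurrence rank within its row (0 the first time a value appears, 1 the second time, and so on), and only pair an input position with an output position when both the value and the rank agree. Since a (value, rank) pair is unique within a row, each output position has at most one matching input position, unlike matching on value alone. It still materializes boolean tensors of shape (batch_size, input_len, output_len), (batch_size, input_len, input_len) and (batch_size, output_len, output_len), so it trades memory for speed. The helper names `occurrence_rank` and `replace_kth_occurrence` are hypothetical, and `torch.isin` requires PyTorch 1.10 or later.

```python
import torch

# Example tensors copied from the printed output in the question
vocab_size = 20
input_ids = torch.tensor([[0, 8, 7, 12, 8],
                          [14, 15, 9, 7, 10]])
output_ids = torch.tensor([[2, 8, 3, 15, 2, 19, 7, 1, 19, 8],
                           [10, 8, 0, 7, 16, 0, 6, 2, 16, 13]])


def occurrence_rank(x):
    """rank[b, j] = number of positions j' < j in row b with x[b, j'] == x[b, j]."""
    n = x.shape[-1]
    same = x.unsqueeze(-1) == x.unsqueeze(-2)                             # (B, n, n)
    earlier = torch.ones(n, n, dtype=torch.bool, device=x.device).tril(-1)  # j' < j
    return (same & earlier).sum(-1)                                        # (B, n)


def replace_kth_occurrence(input_ids, output_ids, vocab_size, special=(0, 1, 2)):
    special_values = torch.tensor(special, device=output_ids.device)
    keep = ~torch.isin(output_ids, special_values)                         # (B, out)
    # match[b, j, k]: output_ids[b, k] has the same value AND the same occurrence
    # rank as input_ids[b, j], and is not a special value
    match = (
        (input_ids.unsqueeze(-1) == output_ids.unsqueeze(1))
        & (occurrence_rank(input_ids).unsqueeze(-1)
           == occurrence_rank(output_ids).unsqueeze(1))
        & keep.unsqueeze(1)
    )                                                                      # (B, in, out)
    # At most one j matches each (b, k), so summing j * match over j
    # recovers that index (0 where nothing matches; masked out below).
    j = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1, 1)
    index = (match * j).sum(dim=1)                                         # (B, out)
    return torch.where(match.any(dim=1), vocab_size + index, output_ids)


print(replace_kth_occurrence(input_ids, output_ids, vocab_size))
```

On the example tensors this prints the same result as the hand-edited expected output in the question, with 21 and 24 for the two 8s in the first row.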