从T5ForConditionalGeneration确定decoder_hidden_states的内容

问题描述 投票:0回答:1

我正在使用 Huggingface

T5ForConditionalGeneration
模型,无需修改。

我想计算 T5 解码器最后一个隐藏状态的平均池化,但我无法确定

decoder_hidden_states
的哪一部分包含我正在寻找的内容。

我想做这样的事情:

# Prepare batch data
sources = batch_df['Source'].tolist()
tokenized_input = self.tokenizer(sources, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length).to('cuda')
input_ids = tokenized_input['input_ids'].to('cuda')
attention_mask = tokenized_input['attention_mask'].to('cuda')

input_batch = {
    'input_ids': input_ids, 
    'attention_mask': attention_mask,
    'do_sample': False,
    'num_beams': 1,
    'eos_token_id': self.tokenizer.eos_token_id,
    'pad_token_id': self.tokenizer.pad_token_id,
    'max_length': self.max_output_length,
    'output_scores': True,
    'return_dict_in_generate': True,
    'output_hidden_states': True,
}
outputs = self.model.generate(**input_batch)

# Retrieve the decoder hidden states
decoder_last_hidden_state = outputs.decoder_hidden_states[-1]  # Last layer's hidden states

# Compute the mean of the hidden states across the sequence length dimension
mean_pooled_output = torch.mean(decoder_last_hidden_state, dim=1, keepdim=False)

这种方法适用于编码器,但对于解码器来说,

decoder_hidden_states[-1]
是张量的元组,而不是张量。

当我第一次检查元组时,有 10 个元组,每个元组包含 7 个张量。

当我检查尺寸时,像这样:

for tuple_number in range(n):  # Checking the tuples
    print(f"Tuple {layer_number}:")
    for i, tensor in enumerate(outputs.decoder_hidden_states[layer_number]):
        print(f"  Tuple {i} in Layer {layer_number}: shape {tensor.shape}")

输出都是这样的:

Tuple 0:
  Tensor 0 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 1 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 2 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 3 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 4 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 5 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 6 in Tuple 0: shape torch.Size([2, 1, 512])
Tuple 1:
  Tensor 0 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 1 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 2 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 3 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 4 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 5 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 6 in Tuple 1: shape torch.Size([2, 1, 512])
. . .

512 是我的分词器的 max_length,2 是我的批量大小。 (我验证了 2 是批量大小,因为当我修改批量大小时该数字发生了变化。)

然后,当我将输入字符串的长度修剪为 10 个字符时,令我惊讶的是,元组的数量从 10 个变为 39 个。当我将字符串进一步修剪为每个字符串仅 2 个字符时,元组的数量并没有减少。增加到39以上。 然后,当我将输入字符串长度加倍时,元组的数量下降到 7。因此,元组的数量似乎对应于解码器在某些块大小(达到某些限制)上的迭代。

因此,如果我想计算第一个标记的均值池,似乎我会计算第一个元组的最后一个张量的均值。 但是,我不明白令牌长度如何对应于元组的数量。

如何确定这些元组和张量到底代表什么?我通过查看 T5 源代码未能成功找到此信息。

pytorch nlp huggingface-transformers
1个回答
0
投票

我认为发生的情况是 T5 在解码的每一步都返回隐藏状态。因此,元组的数量应该与最长的生成序列相对应。您很可能对最后一个解码步骤感兴趣,并且可以采用最后一个元组。

在该元组中,您有一个大小为 num_layers + 1 的元组(+1 表示最终的 LayerNorm)。最后一层的输出应该是最后一个元组条目。

© www.soinside.com 2019 - 2024. All rights reserved.