从BertForSequenceClassification获取特征向量

问题描述 投票:0回答:3

我已经成功地使用BertForSequenceClassification中的huggingface/transformers构建了情感分析工具,可以将$ tsla推文分类为正面还是负面。

但是,我无法从微调模型中获得每条推文的特征向量(更具体地说是[CLS]的嵌入。

二手型号的更多信息:

model = BertForSequenceClassification.from_pretrained(OUTPUT_DIR, num_labels=num_labels)
model.config.output_hidden_states = True
tokenizer = BertTokenizer(OUTPUT_DIR+'vocab.txt')

但是,当我运行output变量下面的代码时,仅包含logits。

model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []

for input_ids, input_mask, segment_ids, label_ids in tqdm_notebook(eval_dataloader, desc="Evaluating"):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    with torch.no_grad():
        output = model(input_ids,token_type_ids= segment_ids,attention_mask= input_mask)
pytorch embedding bert
3个回答
0
投票

[BertForSequenceClassification是一个包装器,由两个部分组成:BERT模型(属性bert)和分类器(属性classifier)。

您可以直接调用基础的BERT模型。如果直接将输入传递给它,则将获得隐藏状态。它返回一个元组:该元组的第一个成员都是隐藏状态,第二个是[CLS]向量。


0
投票

使用令牌生成器获取令牌的ID。在[CLS]的特殊情况下>

>>> tokenizer.cls_token_id
101

将会。更一般地说,是

>>> tokenizer.encode('my text')
[2026, 3793]

获得mytext的ID。

然后,一个好主意是打印模型实例以查看其结构:

>>> print(model)
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1)
    )
...

因此,您只需键入model.bert.embeddings.word_embeddings即可获取嵌入层。这是基本的Embedding layer,可以通过weight属性获得其嵌入矩阵(简单的割炬张量)。使用令牌ID对其进行索引将为您提供向量。因此,总的来说,您只需要这样做:

>>> model.bert.embeddings.word_embeddings.weight[101]
tensor([ 1.3630e-02, -2.6490e-02, -2.3503e-02, -7.7876e-03,  8.5892e-03,
        -7.6645e-03, -9.8808e-03,  6.0184e-03,  4.6921e-03, -3.0984e-02,
         1.8883e-02, -6.0093e-03, -1.6652e-02,  1.1684e-02, -3.6245e-02,
         8.3482e-03, -1.2112e-03,  1.0322e-02,  1.6692e-02, -3.0354e-02,
         ...

0
投票

在对BertForSequenceClassification进行微调后,我也遇到了这个问题。我知道您的目的是获取[CLS]的隐藏状态作为每个tweet的表示。对?按照API document的说明,我认为代码是:

© www.soinside.com 2019 - 2024. All rights reserved.