HuggingFace Transformers Trainer._maybe_log_save_evaluate IndexError: invalid index to scalar variable

Question · Votes: 0 · Answers: 1

So, I am fine-tuning a BART model for question generation, and it appears to be training. Then, suddenly, it stops at the end of the first validation pass with the IndexError you can see below. The problem occurs in the Trainer._maybe_log_save_evaluate method that is being called.

Here is the code where I set up the model, tokenizer, dataset, etc.:

from datasets import load_dataset
from evaluate import load
from accelerate import Accelerator
from transformers import BartForConditionalGeneration, BartConfig, BartTokenizer
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer 

dataset = load_dataset("squad")
metric = load("squad")
accelerator = Accelerator()

def model_init():
  config = BartConfig()
  return accelerator.prepare(BartForConditionalGeneration(config).from_pretrained("facebook/bart-base").cuda())

tokenizer = accelerator.prepare(BartTokenizer.from_pretrained("facebook/bart-base"))

def preprocess_function(data):
  inputs = tokenizer(data['context'], add_special_tokens=True, max_length=256, padding="max_length", truncation=True)
  targets = tokenizer(data['question'], add_special_tokens=True, max_length=32, padding="max_length", truncation=True)
  return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'labels': targets['input_ids']}

dataset = dataset.map(preprocess_function, batched=True).shuffle(seed=777)

training_args = Seq2SeqTrainingArguments(
  output_dir="./results",
  evaluation_strategy="steps",
  eval_steps=500,
  save_steps=50000,
  learning_rate=2e-5,
  per_device_train_batch_size=4,
  per_device_eval_batch_size=4,
  num_train_epochs=2,
  weight_decay=0.01,
  predict_with_generate=True,
)

def compute_metrics(eval_pred):
  predictions, labels = eval_pred
  predictions = predictions.argmax(axis=-1)
  return metric.compute(predictions=predictions, references=labels)

trainer = Seq2SeqTrainer(
  args=training_args,
  train_dataset=dataset["train"],
  eval_dataset=dataset["validation"],
  tokenizer=tokenizer,
  model_init=model_init,
  compute_metrics=compute_metrics,
)

trainer.train()

I can't seem to figure out why this is happening, and nothing I've found online has helped.

python pytorch nlp huggingface-transformers huggingface
1 Answer

Score: 0

Your problem comes from your compute_metrics function: you are using a question-answering metric (squad) with a text-generation model.
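For context, the squad metric does not accept raw prediction/label arrays; it expects structured dicts keyed by example id. A minimal sketch of the shapes it wants (the id and values here are made up):

```python
# What evaluate.load("squad") expects as input (example values are made up):
predictions = [{"id": "q1", "prediction_text": "Denver Broncos"}]
references = [{
    "id": "q1",
    "answers": {"text": ["Denver Broncos"], "answer_start": [177]},
}]

# Feeding it raw token-id or logit arrays instead, as the question's
# compute_metrics does, is what makes the metric fail during evaluation.
print(predictions[0]["prediction_text"])  # → Denver Broncos
```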

To fix it, replace metric = load("squad") with a text-generation metric such as BLEU: metric = load("bleu"). Then adjust your compute_metrics function accordingly:

def compute_metrics(eval_pred):
    predictions, references = eval_pred
    # With predict_with_generate=True, predictions are generated token ids,
    # so decode them directly (no argmax needed).
    predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    references = tokenizer.batch_decode(references, skip_special_tokens=True)
    # BLEU expects a list of reference strings per prediction.
    references = [[ref] for ref in references]
    return metric.compute(predictions=predictions, references=references)
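One caveat worth noting: if you later switch to DataCollatorForSeq2Seq, padded label positions are set to -100, which tokenizer.batch_decode cannot handle. A minimal sketch of the masking step to do before decoding (the label values here are made up; pad id 1 is BART's):

```python
import numpy as np

# Hypothetical label batch: -100 marks positions ignored by the loss.
labels = np.array([[0, 2264, 16, 2],
                   [0, 2264, -100, -100]])
pad_token_id = 1  # BART's pad token id

# Replace -100 with the pad id so tokenizer.batch_decode can process it.
labels = np.where(labels != -100, labels, pad_token_id)
print(labels.tolist())  # → [[0, 2264, 16, 2], [0, 2264, 1, 1]]
```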