So, I'm fine-tuning a BART model for question generation, and it seems to train fine. Then it suddenly stops at the end of the first validation pass with an
IndexError
, which you can see below. The problem occurs inside the
Trainer._maybe_log_save_evaluate
method that gets called.
Here is the code where I set up the model, tokenizer, dataset, etc.:
from datasets import load_dataset
from evaluate import load
from accelerate import Accelerator
from transformers import BartForConditionalGeneration, BartConfig, BartTokenizer
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
dataset = load_dataset("squad")
metric = load("squad")
accelerator = Accelerator()
def model_init():
    config = BartConfig()
    return accelerator.prepare(BartForConditionalGeneration(config).from_pretrained("facebook/bart-base").cuda())

tokenizer = accelerator.prepare(BartTokenizer.from_pretrained("facebook/bart-base"))

def preprocess_function(data):
    inputs = tokenizer(data['context'], add_special_tokens=True, max_length=256, padding="max_length", truncation=True)
    targets = tokenizer(data['question'], add_special_tokens=True, max_length=32, padding="max_length", truncation=True)
    return {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'labels': targets['input_ids']}
dataset = dataset.map(preprocess_function, batched=True).shuffle(seed=777)
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=50000,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    predict_with_generate=True,
)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Seq2SeqTrainer(
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    model_init=model_init,
    compute_metrics=compute_metrics,
)
trainer.train()
I can't figure out why this happens, and nothing I've found online has helped.
Your problem comes from your
compute_metrics
function: you are using a QA metric with a text-generation model.
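For context on why that mismatch breaks: the "squad" metric scores dictionaries keyed by example id, with answer spans, not the raw token-id arrays the trainer hands to compute_metrics. A minimal sketch of the input shape it expects (field names as documented for the evaluate library; the id and answer values are made up):

```python
# What the "squad" metric wants to receive -- structured dicts, not arrays.
# Feeding it generated token-id arrays instead is what trips it up here.
predictions = [{"id": "example-1", "prediction_text": "Denver Broncos"}]
references = [{"id": "example-1",
               "answers": {"text": ["Denver Broncos"],
                           "answer_start": [177]}}]
# metric.compute(predictions=predictions, references=references)
# returns exact-match / F1 scores for inputs shaped like this.
```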
To fix it, replace
metric = load("squad")
with a text-generation metric such as BLEU: metric = load("bleu")
, and adjust your compute_metrics
function accordingly:
def compute_metrics(eval_pred):
    predictions, references = eval_pred
    predictions = tokenizer.batch_decode(predictions)
    references = tokenizer.batch_decode(references)
    references = [[ref] for ref in references]
    return metric.compute(predictions=predictions, references=references)
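One extra detail worth handling if you later mask label padding with -100 (the common setup, since -100 is the ignore index used by the loss): tokenizer.batch_decode cannot decode -100, so map those positions back to the pad token id first. A minimal sketch; the helper name is mine, and the default pad id of 1 matches facebook/bart-base, but reading it from tokenizer.pad_token_id is safer:

```python
import numpy as np

def restore_pad_tokens(labels, pad_token_id=1):
    # Hypothetical helper: label positions masked with -100 are ignored
    # by the loss but cannot be decoded, so map them back to the pad id
    # before calling tokenizer.batch_decode.
    labels = np.asarray(labels)
    return np.where(labels == -100, pad_token_id, labels)
```

With that in place, references = tokenizer.batch_decode(restore_pad_tokens(references), skip_special_tokens=True) decodes cleanly, with skip_special_tokens dropping the pad tokens from the output strings.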