How do I prevent the transformers generate function from producing certain words?

I have the following code:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids

sequence_ids = model.generate(input_ids)
sequences = tokenizer.batch_decode(sequence_ids)
sequences

Currently it produces this:

['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']

Is there a way to prevent the generator from producing certain words (for example, park or offer)?
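One possible direction, sketched here rather than confirmed as the accepted answer: generate accepts a bad_words_ids argument, a list of token-id sequences that must not appear in the output. The helper names build_bad_words_ids and generate_without_words below are made up for illustration, and the choice of banned words is an assumption. The ban works at the token level, so inflected forms (offer vs. offers) may need to be listed separately.

```python
def build_bad_words_ids(tokenizer, words):
    """Return one token-id sequence per banned word.

    add_special_tokens=False keeps the end-of-sequence token (</s>) out of
    the banned sequences, so only the word's own tokens are blocked.
    """
    return tokenizer(list(words), add_special_tokens=False).input_ids


def generate_without_words(text, banned_words, model_name="t5-small"):
    """Hypothetical wrapper: generate from `text` while banning `banned_words`."""
    # Imported here so the helper above can be used without loading the library.
    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    input_ids = tokenizer(text, return_tensors="pt").input_ids
    bad_words_ids = build_bad_words_ids(tokenizer, banned_words)

    # Any continuation that would complete a banned sequence is masked
    # out of the candidate tokens during decoding.
    sequence_ids = model.generate(input_ids, bad_words_ids=bad_words_ids)
    return tokenizer.batch_decode(sequence_ids)
```

Calling generate_without_words("The <extra_id_0> walks in <extra_id_1> park", ["park", "offers"]) should then return a decoded sequence that avoids those token sequences; what the model produces instead depends on the checkpoint.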

python huggingface-transformers generative-pretrained-transformer