如何在 transformers 库中实现 `stopping_criteria` 参数?

问题描述 投票:0回答:1

我正在为

transformers
模型使用 python huggingface
text-generation
库。我需要知道如何在我正在使用的
stopping_criteria
函数中实现
generator()
参数。

我在本文档中找到了

stopping_criteria
参数: https://huggingface.co/transformers/main_classes/pipelines.html#transformers.TextGenerationPipeline

问题是,我只是不知道如何实现它。

我的代码:

from transformers import pipeline
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
stl = StoppingCriteria(['###'])
res = generator(prompt, do_sample=True,stopping_criteria = stl)
python generator documentation huggingface-transformers
1个回答
0
投票

这两种方法对我有用。

your_condition
是 True 当你想停下来的时候。

class CustomStoppingCriteria(StoppingCriteria):
    def __init__(self):
        pass
    
    def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
        return your_condition

stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria()])

def custom_stopping_criteria(input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
    return your_condition

stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])
© www.soinside.com 2019 - 2024. All rights reserved.