Pipeline.fit() 为参数“text_input”获取了多个值

问题描述 投票:0回答:1

我想用promptify进行多标签文本分类。

我收到此错误

`Pipeline.fit() got multiple values for argument 'text_input'`

我的代码:

model = OpenAI(api_key)
prompter = Prompter('multilabel_classification.jinja')
pipe = Pipeline(prompter, model)

classes = ['Medicine','Oncology','Metastasis','Breast cancer','Lung cancer','Cerebrospinal fluid','Tumor microenvironment','Single-cell RNA sequencing','Idiopathic intracranial hypertension']

sent = "The patient is a 93-year-old female with a medical history of chronic right hip pain, osteoporosis, hypertension, depression, and chronic atrial fibrillation admitted for evaluation and management of severe nausea and vomiting and urinary tract infection"
result = pipe.fit('multilabel_classification.jinja',
n_output_labels = len(classes),
domain = 'Clinical',
text_input = sent,
labels = classes)

print(eval(result['text']))

感谢您花时间回复我。

祝你有美好的一天

prompt
1个回答
0
投票

出现此错误的原因是因为

Pipeline.fit()
函数默认将第一个参数视为text_input。因此,字符串
'multilabel_classification.jinja'
被分配给 text_input 参数。你的函数应该是这样的:

result = pipe.fit(n_output_labels = len(classes), 
                  domain = 'Clinical', 
                  text_input = sent, labels = classes)

请看源码:

© www.soinside.com 2019 - 2024. All rights reserved.