I'm using the transformers_interpret library to explain word attributions for a fine-tuned model in an NLP task. I want to print all attribution scores and their visualizations whenever a text's cumulative attribution score exceeds a specified threshold.
Here is a simplified version of my code:
# Set the desired attribution score threshold
attribution_threshold = 2.50

# Set the desired maximum sequence length
max_sequence_length = 512

for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT']
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Check if the overall attribution score is above the threshold
    overall_score = sum(abs(attribution) for attribution in word_attributions)
    if overall_score > attribution_threshold:
        cls_explainer.visualize()
However, I'm running into an issue: word_attributions contains (word, score) tuples rather than a single score per word, so `abs(attribution)` raises a TypeError. My goal is to adjust this code so that it handles the tuple values correctly and performs the desired check on individual attribution scores.
How can I adjust the code to print all attribution scores and their visualizations whenever any individual attribution score within the tuples exceeds the specified threshold?
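For the per-token check asked about here, the tuples just need to be unpacked. A minimal sketch, using a made-up list of scores in place of a real explainer output (in transformers_interpret, `cls_explainer(text)` returns a list of `(token, score)` tuples):

```python
# Hypothetical attribution output; real values would come from
# word_attributions = cls_explainer(truncated_text)
word_attributions = [("the", 0.12), ("movie", 1.95), ("was", -0.30), ("great", 2.61)]
attribution_threshold = 2.50

# Summing tuples directly fails; unpack (token, score) and sum the magnitudes
overall_score = sum(abs(score) for _, score in word_attributions)

# True if any single token's attribution exceeds the threshold
exceeds = any(abs(score) > attribution_threshold for _, score in word_attributions)

if exceeds:
    for token, score in word_attributions:
        print(f"{token}: {score:.3f}")
```

With a real explainer, `cls_explainer.visualize()` would replace (or accompany) the print loop inside the `if` block.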
I found a solution. The problem is that the documentation is a bit messy.
Here is the code that solves the problem:
# Set the desired maximum sequence length
max_sequence_length = 512
attribution_threshold = 1.8  # Adjust the threshold as needed

# List to store word attributions for each row
all_word_attributions = []

# Loop through all rows in the dataset
for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT']  # Assuming 'TEXT' is the correct column name
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Sum of attributions for the entire text
    total_attribution = cls_explainer.attributions.attributions_sum.sum()

    # Print the sum of attributions
    print(f'Total Attribution for Row {row_index}: {total_attribution}')

    # Save the attributions for this row
    all_word_attributions.append(word_attributions)

    # Visualize if the total attribution is above the threshold
    if abs(total_attribution) > attribution_threshold:
        cls_explainer.visualize()
As you can see,
total_attribution = cls_explainer.attributions.attributions_sum.sum()
is the part that does the work.
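If reaching into the explainer's internal `attributions` attribute feels fragile, the same signed total can be computed directly from the `(token, score)` tuples the explainer returns; this sketch assumes `attributions_sum` holds the per-token scores, and again uses made-up values in place of real explainer output:

```python
# Hypothetical attribution output; real values would come from
# word_attributions = cls_explainer(truncated_text)
word_attributions = [("the", 0.12), ("movie", 1.95), ("was", -0.30), ("great", 2.61)]

# Signed sum over tokens, equivalent (under the assumption above) to
# cls_explainer.attributions.attributions_sum.sum()
total_attribution = sum(score for _, score in word_attributions)
```

This keeps the threshold check tied to the documented return value rather than an internal field.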
Hope this helps.