Handling tuple values in word attributions from Transformers Interpret

Problem description

I am using the transformers_interpret library to explain word attributions for a fine-tuned model on an NLP task. I want to print all attribution scores and show the visualization for any text whose cumulative attribution score exceeds a specified threshold.
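For context, `cls_explainer` in the snippet below is a SequenceClassificationExplainer from transformers_interpret; a minimal sketch of the assumed setup (the model name and DataFrame here are placeholders, not my actual ones):

import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Placeholder fine-tuned checkpoint; substitute your own model
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The explainer is called directly on a piece of text
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Placeholder DataFrame with the 'TEXT' column used below
df = pd.DataFrame({"TEXT": ["I loved this movie.", "The plot made no sense."]})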

Here is a simplified version of my code:

# Set the desired attribution score threshold
attribution_threshold = 2.50

# Set the desired maximum sequence length
max_sequence_length = 512

for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT'] 
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Check if the overall attribution score is above the threshold
    # (this is where it breaks: each element of word_attributions is a
    # (token, score) tuple, so abs(attribution) raises a TypeError)
    overall_score = sum(abs(attribution) for attribution in word_attributions)
    if overall_score > attribution_threshold:
        cls_explainer.visualize()

However, I ran into the issue that word_attributions contains tuples rather than a single score per word. My goal is to modify this code so that it handles the tuple values correctly and performs the desired check on the individual attribution scores.
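For reference, the returned value looks roughly like this (the tokens and scores below are only illustrative):

# word_attributions is a list of (token, score) tuples, roughly:
# [('[CLS]', 0.0), ('i', 0.12), ('loved', 0.83), ('this', 0.05), ('movie', 0.41), ('[SEP]', 0.0)]
for token, score in word_attributions:
    print(f"{token}: {score:.4f}")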

How can I adjust the code so that it prints all attribution scores and shows the visualization whenever any individual attribution score within the tuples exceeds the specified threshold?

Tags: tuples, huggingface-transformers, interpreter, attribution
1 Answer

I found a solution. The problem is that the documentation is a bit messy.

Here is the code that solves the problem:

# Set the desired maximum sequence length
max_sequence_length = 512
attribution_threshold = 1.8  # Adjust the threshold as needed

# List to store word attributions for each row
all_word_attributions = []

# Loop through all rows in the dataset
for row_index in range(len(df)):
    text = df.loc[row_index, 'TEXT']  # Assuming 'TEXT' is the correct column name
    truncated_text = text[:max_sequence_length]
    word_attributions = cls_explainer(truncated_text)

    # Sum of attributions for the entire text
    total_attribution = cls_explainer.attributions.attributions_sum.sum()

    # Print the sum of attributions
    print(f'Total Attribution for Row {row_index}: {total_attribution}')

    # Save the word attributions for this row
    all_word_attributions.append(word_attributions)

    # Visualize if the total attribution is above the threshold
    if abs(total_attribution) > attribution_threshold:
        cls_explainer.visualize()

As you can see,

total_attribution = cls_explainer.attributions.attributions_sum.sum()

is the line that does the work.
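If you would rather work with the (token, score) tuples returned by the explainer directly, the same kind of total, as well as a per-word check, can be written by unpacking them; a small sketch (the per-word threshold is just an example value):

# Total attribution computed from the tuples themselves
total_attribution = sum(score for _, score in word_attributions)

# Or trigger the visualization when any single word's score exceeds a per-word threshold
word_threshold = 0.5  # example value, adjust as needed
if any(abs(score) > word_threshold for _, score in word_attributions):
    cls_explainer.visualize()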

Hope this helps.
