sklearn预处理中输出列的保持轨迹

问题描述 投票:1回答:1

如何跟踪由sklearn.compose.ColumnTransformer产生的转换数组的列? “保持跟踪”是指执行逆变换所需的每一位信息都必须显式显示。这至少包括以下内容:

  1. 输出数组中每一列的源变量是什么?
  2. 如果输出数组的一列来自分类变量的一键编码,那是什么类别?
  3. 每个变量的确切推算值是多少?
  4. 用于标准化每个数字变量的(平均值,标准偏差)是什么? (由于估算的缺失值,这些可能与直接计算有所不同。)

我正在基于this answer使用相同的方法。我的输入数据集也是具有多个数值和分类列的通用pandas.DataFrame。是的,该答案可以转换原始数据集。但是我没有跟踪输出数组中的列。我需要这些信息用于同行评审,报告撰写,演示和进一步的模型构建步骤。我一直在寻找系统的方法,但是没有运气。

python scikit-learn preprocessor
1个回答
0
投票

提到的答案是基于Sklearn中的this

您可以使用以下代码段获取前两个问题的答案。

def get_feature_names(columnTransformer):

    output_features = []

    for name, pipe, features in columnTransformer.transformers_:
        if name!='remainder':
            for i in pipe:
                trans_features = []
                if hasattr(i,'categories_'):
                    trans_features.extend(i.get_feature_names(features))
                else:
                    trans_features = features
            output_features.extend(trans_features)

    return output_features
import pandas as pd
pd.DataFrame(preprocessor.fit_transform(X_train),
            columns=get_feature_names(preprocessor))

enter image description here

transformed_cols = get_feature_names(preprocessor)

def get_original_column(col_index):
    return transformed_cols[col_index].split('_')[0]

get_original_column(3)
# 'embarked'

get_original_column(0)
# 'age'
def get_category(col_index):
    new_col = transformed_cols[col_index].split('_')
    return 'no category' if len(new_col)<2 else new_col[-1]

print(get_category(3))
# 'Q'

print(get_category(0))
# 'no category'

对于当前版本的Sklearn来说,跟踪某个功能是否进行了插补或缩放并不容易。

© www.soinside.com 2019 - 2024. All rights reserved.