ML ColumnTransformer OneHotEncoder


I'm seeing strange behavior from ColumnTransformer and OneHotEncoder when encoding the categorical data in the first column of a DataFrame. It happens when I add a row to the CSV file.

The initial data is:

title,dailygross,theaters,DayInYear
ACatinParis,307,5,257
ALettertoMomo,307,5,257
AnotherDayofLife,307,5,257
ApprovedforAdoption,307,5,257
AprilandtheExtraordinaryWorld,307,5,257
Belle,307,5,257
BirdboyTheForgottenChildren,307,5,257
ChicoRita,307,5,257

When I run this code:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

dataset = pd.read_csv('../data/GKIDS_DayNum_test_names.csv')
dataset['title'] = dataset['title'].str.strip()  # str.strip() returns a new Series, so assign it back
X = dataset.iloc[:, :-1].values
y = dataset.iloc[:, -1].values

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

title_column_index = dataset.columns.get_loc('title')
print('title index:', title_column_index)
ct = ColumnTransformer(transformers=[('encoder', OneHotEncoder(), [title_column_index])], remainder='passthrough')
X_Encoded = np.array(ct.fit_transform(X))
print(X_Encoded)

the result is as expected:

[[1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 307 5]
 [0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 307 5]
 [0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 307 5]
 [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 307 5]
 [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 307 5]
 [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 307 5]
 [0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 307 5]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 307 5]]

However, when I add one more row

BlueGiant,307,5,257
to the file and re-run the code, I get this strange output:

  (0, 0)    1.0
  (0, 9)    307.0
  (0, 10)   5.0
  (1, 1)    1.0
  (1, 9)    307.0
  (1, 10)   5.0
  (2, 2)    1.0
  (2, 9)    307.0
  (2, 10)   5.0
  (3, 3)    1.0
  (3, 9)    307.0
  (3, 10)   5.0
  (4, 4)    1.0
  (4, 9)    307.0
  (4, 10)   5.0
  (5, 5)    1.0
  (5, 9)    307.0
  (5, 10)   5.0
  (6, 6)    1.0
  (6, 9)    307.0
  (6, 10)   5.0
  (7, 8)    1.0
  (7, 9)    307.0
  (7, 10)   5.0
  (8, 7)    1.0
  (8, 9)    307.0
  (8, 10)   5.0

I don't understand why this happens.

Please help.

python machine-learning one-hot-encoding
1 Answer

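The latter is how a scipy sparse matrix prints: each line is a (row, column) position followed by the value stored there. Wrapping the result in np.array does not make it dense; it only puts the sparse matrix inside a 0-d object array.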
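A minimal sketch of that effect, using a small hand-made CSR matrix as a stand-in for ColumnTransformer's sparse output (the values are illustrative, not the question's data). It shows that np.array leaves the sparse matrix wrapped as a single object, while .toarray() is the explicit conversion to a dense ndarray.

```python
import numpy as np
from scipy import sparse

# A tiny 2x3 sparse matrix standing in for ColumnTransformer's sparse output
# (illustrative values only, not taken from the question's CSV).
m = sparse.csr_matrix(np.array([[1.0, 0.0, 307.0],
                                [0.0, 1.0, 307.0]]))

# np.array() does not densify a scipy sparse matrix: it wraps it in a
# 0-d object array, and printing that shows the sparse (row, col) listing.
wrapped = np.array(m)
print(wrapped.shape, wrapped.dtype)
print(wrapped)

# .toarray() is the explicit way to get a dense 2-D ndarray back.
dense = m.toarray()
print(dense)
```

So if you want a plain 2-D array regardless of what ColumnTransformer returns, check with sparse.issparse and call .toarray() rather than relying on np.array.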

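OneHotEncoder produces a sparse matrix as its output, and ColumnTransformer returns a sparse matrix whenever the overall density of its combined output is lower than its sparse_threshold parameter (default 0.3). Here each row has one stored value in the one-hot block plus the two passthrough values, which ColumnTransformer counts as nonzero. In the first dataset that gives a density of 3/10 = 0.3, exactly the default sparse_threshold; because the test is strictly "lower than", the output stays dense. The new row adds a ninth title and therefore a new one-hot column, so the density drops to 3/11 (about 0.27), below the threshold, and the output becomes sparse.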
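A self-contained check of this arithmetic, assuming scikit-learn and SciPy are installed. The data is built in memory to mirror the question's CSV (no file is read), and encode and expected_density are hypothetical helpers written only for this illustration. The density formula reflects how ColumnTransformer counts a dense passthrough block (every value as stored), so it is a sketch of the reasoning above rather than library code.

```python
import numpy as np
from scipy import sparse
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# The eight titles from the question's original CSV.
TITLES = ["ACatinParis", "ALettertoMomo", "AnotherDayofLife",
          "ApprovedforAdoption", "AprilandtheExtraordinaryWorld", "Belle",
          "BirdboyTheForgottenChildren", "ChicoRita"]


def encode(titles, **ct_kwargs):
    """Hypothetical helper: one-hot encode column 0 and pass the two numeric
    columns through, mirroring the question's X (title, dailygross, theaters)."""
    X = np.array([[t, 307, 5] for t in titles], dtype=object)
    ct = ColumnTransformer([("encoder", OneHotEncoder(), [0])],
                           remainder="passthrough", **ct_kwargs)
    return ct.fit_transform(X)


def expected_density(n_categories, n_passthrough=2):
    """Hypothetical helper: per row, 1 stored value in the one-hot block plus
    every passthrough value, divided by the total number of output columns."""
    return (1 + n_passthrough) / (n_categories + n_passthrough)


print(expected_density(8), expected_density(9))  # 0.3 vs 3/11

out8 = encode(TITLES)                  # density == 0.3, not < 0.3 -> dense
out9 = encode(TITLES + ["BlueGiant"])  # density == 3/11 < 0.3 -> sparse
print(sparse.issparse(out8), sparse.issparse(out9))

# sparse_threshold=0 makes ColumnTransformer always return a dense array.
dense9 = encode(TITLES + ["BlueGiant"], sparse_threshold=0)
print(type(dense9), dense9.shape)
```

Passing sparse_threshold=0 is one way to get a dense result regardless of the number of titles; another is to ask OneHotEncoder itself for dense output (sparse_output=False in scikit-learn 1.2 and later, sparse=False in older versions), or to call .toarray() on a sparse result. Note also that BlueGiant sorts between BirdboyTheForgottenChildren and ChicoRita, which is why rows 7 and 8 in the sparse printout map to columns 8 and 7.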

© www.soinside.com 2019 - 2024. All rights reserved.