我正在尝试绘制我训练的 XGBoost 模型的 SHAP 分析图。与this类似的东西。
但是,我使用了 Dart booster,所以
shap.TreeExplainer
不起作用。然后,我尝试使用 shap.KernelExplainer
,它应该对我有用。但是,它不接受任何常见类型的输入。
我的代码是这样的:
第一次尝试
# Data to predict
full_data = xgb.DMatrix(full_X, label=full_y, feature_names=feature_names)
# Pre-trained XGB model using DART booster
loaded_model.set_param({"device": "cuda"})
xgb_predict = lambda x: loaded_model.predict(x)
explainer = shap.KernelExplainer(xgb_predict, full_data)
我明白了:
TypeError: Unknown type passed as data object: <class 'xgboost.core.DMatrix'>
第二次尝试
我还尝试提供一个 numpy 数组:
X_np = np.array(full_X)
explainer = shap.KernelExplainer(xgb_predict, X_np)
但它也返回一个错误:
TypeError: ('Expecting data to be a DMatrix object, got: ', <class 'numpy.ndarray'>)
我正在使用 shap 0.44.0 和 xgboost 2.0.2
如何解决这个问题?
到底发生了什么
如果其他人遇到这个问题,这是我发现的:
shap.KernelExplainer
尝试转换数据(源代码在here和here):
def convert_to_data(val, keep_index=False):
if isinstance(val, Data):
return val
elif type(val) == np.ndarray:
return DenseData(val, [str(i) for i in range(val.shape[1])])
elif str(type(val)).endswith("'pandas.core.series.Series'>"):
return DenseData(val.values.reshape((1,len(val))), list(val.index))
elif str(type(val)).endswith("'pandas.core.frame.DataFrame'>"):
if keep_index:
return DenseDataWithIndex(val.values, list(val.columns), val.index.values, val.index.name)
else:
return DenseData(val.values, list(val.columns))
elif sp.sparse.issparse(val):
if not sp.sparse.isspmatrix_csr(val):
val = val.tocsr()
return SparseData(val)
else:
assert False, "Unknown type passed as data object: "+str(type(val))
所以它基本上不识别
xgboost.core.DMatrix
类型。但是,如果输入 Dataframe 或 numpy 数组,它会通过此转换,但在将其传递给模型时会失败,因为该模型是使用 DMatrix
进行训练的。
解决方法
为了解决这个问题,我将 pandas DataFrame 作为数据传递给
shap.KernelExplainer
,并在提供的函数内添加了到 DMatrix
的转换,该函数返回模型的预测:
def xgb_predict(X, model = loaded_model, target=full_y, features=feature_names):
# Conversion to a DMatrix
full_data = xgb.DMatrix(X, label=target, feature_names=features)
return model.predict(full_data)
# full_X is a pandas DataFrame
explainer = shap.KernelExplainer(model = xgb_predict, data = full_X)