This is my model:
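from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Conv1D, MaxPooling1D, LSTM, Dense

# vocab_size and embedding_matrix come from the tokenizer / pretrained vectors (not shown)
model = Sequential()
model.add(Embedding(vocab_size, output_dim=100, weights=[embedding_matrix],
                    input_length=1000, trainable=False))  # frozen pretrained embeddings
model.add(Conv1D(32, 4, activation='relu'))
model.add(MaxPooling1D())
model.add(Conv1D(64, 4, activation='relu'))
model.add(MaxPooling1D())
model.add(LSTM(units=128))
model.add(Dense(1, activation='sigmoid'))  # single sigmoid unit for binary classification
model.summary()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])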
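As a sanity check on this stack, the sequence length that reaches the LSTM can be worked out by hand. The sketch below assumes the Keras defaults for these layers (Conv1D with strides=1 and padding='valid'; MaxPooling1D with pool_size=2, strides=pool_size and padding='valid'); the helper names are hypothetical and not part of Keras. It should come out as 247 steps with 64 channels, which matches the `(…, 247, 64)` input reported for the `lstm` layer in the error.

```python
def conv1d_valid_len(length, kernel_size, stride=1):
    # Conv1D with padding='valid' (assumed Keras default): the kernel must fit
    # entirely inside the sequence, so no zero-padding is added.
    return (length - kernel_size) // stride + 1


def maxpool1d_valid_len(length, pool_size=2, stride=None):
    # Assumed MaxPooling1D defaults: pool_size=2, strides=pool_size, padding='valid'.
    if stride is None:
        stride = pool_size
    return (length - pool_size) // stride + 1


def trace_shapes(input_length=1000):
    # Mirrors the layer stack in the question and returns
    # (layer description, time steps, channels) for each layer before the LSTM.
    steps = input_length
    shapes = [("Embedding(output_dim=100)", steps, 100)]
    steps = conv1d_valid_len(steps, kernel_size=4)
    shapes.append(("Conv1D(32, 4)", steps, 32))
    steps = maxpool1d_valid_len(steps)
    shapes.append(("MaxPooling1D()", steps, 32))
    steps = conv1d_valid_len(steps, kernel_size=4)
    shapes.append(("Conv1D(64, 4)", steps, 64))
    steps = maxpool1d_valid_len(steps)
    shapes.append(("MaxPooling1D()", steps, 64))
    return shapes


if __name__ == "__main__":
    for name, steps, channels in trace_shapes():
        print(f"{name:28s} (batch, {steps}, {channels})")
```

Each valid convolution drops kernel_size - 1 steps and each pooling layer roughly halves the length, so 1000 becomes 997, 498, 495 and finally 247 before the LSTM.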
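My SHAP code:
import shap

# first 100 training samples serve as the background distribution
explainer = shap.DeepExplainer(model, x_train[:100])
# explain the first 10 test samples
shap_values = explainer.shap_values(x_test[:10])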
The error I'm getting is:
AttributeError: in user code:
File "/usr/local/lib/python3.10/dist-packages/shap/explainers/_deep/deep_tf.py", line 254, in grad_graph *
out = self.model(shap_rAnD)
File "/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler **
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.10/dist-packages/shap/explainers/_deep/deep_tf.py", line 385, in custom_grad
out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefix before the lookup
File "/usr/local/lib/python3.10/dist-packages/shap/explainers/_deep/deep_tf.py", line 674, in handler
return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
File "/usr/local/lib/python3.10/dist-packages/shap/explainers/_deep/deep_tf.py", line 681, in linearity_with_excluded_handler
assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
File "/usr/local/lib/python3.10/dist-packages/shap/explainers/_deep/deep_tf.py", line 231, in _variable_inputs
out[i] = t.name in self.between_tensors
AttributeError: Exception encountered when calling layer 'lstm' (type LSTM).
'TFDeep' object has no attribute 'between_tensors'
Call arguments received by layer 'lstm' (type LSTM):
• inputs=tf.Tensor(shape=(1000, 247, 64), dtype=float32)
• mask=None
• training=False
• initial_state=None
I've tried many fixes but I still run into the same problem. Can someone please help?
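You need to disable TensorFlow 2 behavior, because SHAP still relies on the old (TF1-style) implementation in TF. Call this at the start of the script, before the model is built and before SHAP is used:
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()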