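I want to use GASearchCV from the sklearn_genetic library to optimize the hyperparameters of my Keras model. First, I defined a function that builds my model. Then I used scikeras to create my estimator. Finally, I called GASearchCV to optimize the hyperparameters and fit it to the data. Here is the code: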
import tensorflow as tf
from tensorflow import keras
from scikeras.wrappers import KerasClassifier
from sklearn_genetic import GASearchCV
from sklearn_genetic.space import Integer, Categorical
def ann_model_ga(number_of_hidden_layer=1,
                 number_of_neurons=50,
                 optimizer="adam"):
    model = keras.models.Sequential()
    model.add(keras.layers.Flatten(input_shape=X_train.shape[1:]))
    for hidden_layer in range(number_of_hidden_layer):
        model.add(keras.layers.Dense(number_of_neurons))
    model.add(keras.layers.Dense(10, activation="softmax"))
    model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy", "AUC"])
    return model
param_grid = {'number_of_hidden_layer': Integer(1, 5),
              'number_of_neurons': Integer(100, 200),
              'optimizer': Categorical(["adam", "sgd"])}
estimator = KerasClassifier(build_fn=ann_model_ga)
evolved_estimator = GASearchCV(estimator=estimator,
                               cv=5,
                               scoring='accuracy',
                               population_size=10,
                               generations=35,
                               tournament_size=3,
                               elitism=True,
                               crossover_probability=0.8,
                               mutation_probability=0.1,
                               param_grid=param_grid,
                               criteria='max',
                               algorithm='eaMuPlusLambda',
                               n_jobs=-1,
                               verbose=True,
                               keep_top_k=4)
history_ga = evolved_estimator.fit(X_train, y_train)
But I got this error:
ValueError Traceback (most recent call last)
<ipython-input-17-36285fd9a509> in <cell line: 1>()
----> 1 history_ga = evolved_estimator.fit(X_train, y_train)
/usr/local/lib/python3.10/dist-packages/scikeras/wrappers.py in set_params(self, **params)
1163 # Give a SciKeras specific user message to aid
1164 # in moving from the Keras wrappers
-> 1165 raise ValueError(
1166 f"Invalid parameter {param} for estimator {self.__name__}."
1167 "\nThis issue can likely be resolved by setting this parameter"
ValueError: Invalid parameter number_of_hidden_layer for estimator KerasClassifier.
This issue can likely be resolved by setting this parameter in the KerasClassifier constructor:
`KerasClassifier(number_of_hidden_layer=1)`
Check the list of available parameters with `estimator.get_params().keys()`
Can anyone help me solve this problem?