How do I use TransformedTargetRegressor in GridSearchCV?


I am trying to use a TransformedTargetRegressor in a model pipeline and run a GridSearchCV on top of it.

Here is a minimal working example:

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.compose import TransformedTargetRegressor


X,y = make_regression()

model_pipe = Pipeline([
    ('model', TransformedTargetRegressor(RandomForestRegressor()))
])

params={'model__n_estimators': [1, 10, 50]}


model = GridSearchCV(model_pipe, param_grid= params)

model.fit(X,y)

This model produces the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-48-828bdf0e7ede> in <module>
     17 model = GridSearchCV(model_pipe, param_grid= params)
     18 
---> 19 model.fit(X,y)

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
    686                 return results
    687 
--> 688             self._run_search(evaluate_candidates)
    689 
    690         # For multi-metric evaluation, store the best_index_, best_params_ and

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/model_selection/_search.py in _run_search(self, evaluate_candidates)
   1147     def _run_search(self, evaluate_candidates):
   1148         """Search all candidates in param_grid"""
-> 1149         evaluate_candidates(ParameterGrid(self.param_grid))
   1150 
   1151 

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/model_selection/_search.py in evaluate_candidates(candidate_params)
    665                                for parameters, (train, test)
    666                                in product(candidate_params,
--> 667                                           cv.split(X, y, groups)))
    668 
    669                 if len(out) < 1:

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
   1001             # remaining jobs.
   1002             self._iterating = False
-> 1003             if self.dispatch_one_batch(iterator):
   1004                 self._iterating = self._original_iterator is not None
   1005 

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
    832                 return False
    833             else:
--> 834                 self._dispatch(tasks)
    835                 return True
    836 

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/parallel.py in _dispatch(self, batch)
    751         with self._lock:
    752             job_idx = len(self._jobs)
--> 753             job = self._backend.apply_async(batch, callback=cb)
    754             # A job can complete so quickly than its callback is
    755             # called before we get here, causing self._jobs to

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/_parallel_backends.py in apply_async(self, func, callback)
    199     def apply_async(self, func, callback=None):
    200         """Schedule a func to be run"""
--> 201         result = ImmediateResult(func)
    202         if callback:
    203             callback(result)

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/_parallel_backends.py in __init__(self, batch)
    580         # Don't delay the application, to avoid keeping the input
    581         # arguments in memory
--> 582         self.results = batch()
    583 
    584     def get(self):

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/parallel.py in __call__(self)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/joblib/parallel.py in <listcomp>(.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/model_selection/_validation.py in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, error_score)
    501     train_scores = {}
    502     if parameters is not None:
--> 503         estimator.set_params(**parameters)
    504 
    505     start_time = time.time()

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/pipeline.py in set_params(self, **kwargs)
    162         self
    163         """
--> 164         self._set_params('steps', **kwargs)
    165         return self
    166 

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/utils/metaestimators.py in _set_params(self, attr, **params)
     48                 self._replace_estimator(attr, name, params.pop(name))
     49         # 3. Step parameters and other initialisation arguments
---> 50         super().set_params(**params)
     51         return self
     52 

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/base.py in set_params(self, **params)
    231 
    232         for key, sub_params in nested_params.items():
--> 233             valid_params[key].set_params(**sub_params)
    234 
    235         return self

~/miniconda3/envs/gymbo/lib/python3.6/site-packages/sklearn/base.py in set_params(self, **params)
    222                                  'Check the list of available parameters '
    223                                  'with `estimator.get_params().keys()`.' %
--> 224                                  (key, self))
    225 
    226             if delim:

ValueError: Invalid parameter n_estimators for estimator TransformedTargetRegressor(check_inverse=True, func=None, inverse_func=None,
                           regressor=RandomForestRegressor(bootstrap=True,
                                                           criterion='mse',
                                                           max_depth=None,
                                                           max_features='auto',
                                                           max_leaf_nodes=None,
                                                           min_impurity_decrease=0.0,
                                                           min_impurity_split=None,
                                                           min_samples_leaf=1,
                                                           min_samples_split=2,
                                                           min_weight_fraction_leaf=0.0,
                                                           n_estimators='warn',
                                                           n_jobs=None,
                                                           oob_score=False,
                                                           random_state=None,
                                                           verbose=0,
                                                           warm_start=False),
                           transformer=None). Check the list of available parameters with `estimator.get_params().keys()`.

When I remove the TransformedTargetRegressor from the pipeline and pass just the random forest, the model runs fine. Why is that, and how can I use a TransformedTargetRegressor inside a pipeline as shown above?

python scikit-learn
2 Answers
0 votes

I have figured out the answer. The TransformedTargetRegressor needs to be wrapped around the grid search estimator, like so:

import numpy as np  # needed for np.log / np.exp below
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.compose import TransformedTargetRegressor


X, y = make_regression()
y = y - y.min() + 1  # shift targets to be strictly positive so np.log is well defined

model_pipe = Pipeline([
    ('model', RandomForestRegressor())
])

params={'model__n_estimators': [1, 10, 50]}


model = TransformedTargetRegressor(GridSearchCV(model_pipe, param_grid= params), func=np.log, inverse_func=np.exp)

model.fit(X,y)
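
With this arrangement the grid search is the regressor inside the TransformedTargetRegressor, so the tuned settings live on the fitted inner search object. A small usage sketch (assuming the model variable from the snippet above):

# TransformedTargetRegressor stores the fitted inner estimator as `regressor_`;
# since that estimator is a GridSearchCV, it exposes the usual search attributes.
print(model.regressor_.best_params_)   # e.g. {'model__n_estimators': 50}
print(model.regressor_.best_score_)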

0 votes

The RandomForestRegressor is stored as the regressor parameter inside the TransformedTargetRegressor.

So the correct way to define the params for GridSearchCV is

params={'model__regressor__n_estimators': [1, 10, 50]}
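
Putting this together with the pipeline from the question, a minimal sketch of the corrected grid search (keeping the TransformedTargetRegressor inside the pipeline) would look like this:

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.compose import TransformedTargetRegressor

X, y = make_regression()

model_pipe = Pipeline([
    ('model', TransformedTargetRegressor(RandomForestRegressor()))
])

# step name -> TransformedTargetRegressor param -> RandomForestRegressor param
params = {'model__regressor__n_estimators': [1, 10, 50]}

model = GridSearchCV(model_pipe, param_grid=params)
model.fit(X, y)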