如何检测参数网格中哪些值是允许的?

问题描述 投票:0回答:3

我已经开始研究一个项目,其中我需要检测给定 scikit-learn 估计器的“可训练”参数,如果可能的话,找到分类变量的允许值(以及连续变量的合理间隔)。 我可以使用

estimator.get_params()

获取带有参数的字典,然后使用

estimator.set_params(**{'var1':val1, 'var2':val2})
设置值,依此类推。
例如,对于 KNN 分类器,我们有以下参数字典:

{'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}

现在,我可以使用值的类型来推断哪些是分类(

str

类型)、连续(

float
)、离散(
int
)等等。一个可能相关的问题是默认设置为
NoneType
的参数,但无论如何我可能不会碰这些,这是有充分理由的。
现在的挑战变成推断和定义参数网格以用于例如

RandomizedSearchCV

。对于离散变量和连续变量,问题可以使用例如

try
-
except
块与 scipy.stats 模块的组合,可能将间隔限制在默认值周围的“附近”(但同时要小心不要将例如
n_jobs
设置为一些疯狂的值——可能需要硬编码,或者稍后显式设置)。如果您有类似的经验,并且有一些技巧/窍门,我很想听听它们。
但现在真正的问题是:如何推断例如

algorithm

允许的值实际上是

{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
??
我刚刚开始研究这个问题,如果我们尝试将其设置为某个不允许的值,也许我们可以解析收到的错误消息?我在这里寻找好的想法,因为我想避免手动执行此操作(如果必须的话我会这样做,但看起来相当不优雅......)

python scikit-learn parameter-passing cross-validation
3个回答
0
投票

因此,我发布我的“解决方案”,以便其他人可以接管并可能对其进行改进。请参阅以下片段:

import re from pprint import pprint from sklearn.neighbors import KNeighborsClassifier knn = KNeighborsClassifier() doc = knn.__doc__ # Get the doc string #from sklearn.svm import SVC #svc = SVC() #doc = svc.__doc__ pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern re.compile(pattern) matches = re.findall(pattern, doc) clf_params = {} previous_param = '' for param, _, value in matches: if ":" in param and param[-4]!="_": # 'Hack-y' if param not in clf_params.keys(): clf_params[param] = list() previous_param = param else: if len(value)>0: clf_params[previous_param].append(value) pprint(clf_params)

此片段打印

{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'], 'leaf_size : ': [], 'metric : ': [], 'metric_params : ': [], 'n_jobs : ': [], 'n_neighbors : ': [], 'p : ': [], 'weights : ': ['uniform', 'distance']}

哪个是正确的。

但是,如果我们对

SVC().__doc__

重复相同的过程,我们会发现它失败了。


我希望

有人

觉得这有点用。


0
投票
splitlines()

:

 的大力帮助
liner = str(LinearSVC().__doc__).split('Parameters\n ----------\n')[1].split('\n\n Attributes\n')[0].replace('\n ', '\n').splitlines()

这不会创建字典,但足够简单,可以从文档字符串中提取解释的“参数”部分,其中解释了所有参数,并列出了所有列出的可能/预期/接受的值输入,这些输入很好地缩进了一,选项卡,现在我们可以使用带有条件的简单循环,使用“:”作为锚点来识别
可能/预期/接受的值输入

for i in liner: ...: if " : " in i: #<<< the key is to use " : " as our anchor ...: print(i)

最终结果,打印到:

penalty : str, 'l1' or 'l2' (default='l2') loss : str, 'hinge' or 'squared_hinge' (default='squared_hinge') dual : bool, (default=True) tol : float, optional (default=1e-4) C : float, optional (default=1.0) multi_class : str, 'ovr' or 'crammer_singer' (default='ovr') fit_intercept : bool, optional (default=True) intercept_scaling : float, optional (default=1) class_weight : {dict, 'balanced'}, optional verbose : int, (default=0) random_state : int, RandomState instance or None, optional (default=None) max_iter : int, (default=1000)

很高兴我可以分享,如果其他人需要完整的文档字符串参数打印输出,只需使用:

print(str(LinearSVC().__doc__).split('Parameters\n ----------\n')[1].split('\n\n Attributes\n')[0].replace('\n ', '\n'))

编辑:

如果这不打算打印出来 - 将其作为字符串对象的最佳方法是使用列表理解,但它需要一些丑陋的替换,因为文档字符串中有 extential 表示法: docstring_short = str([i for i in liner.splitlines() if " : " in i]).replace('[" ', '').replace(' ', ',\n').replace('", "', '').replace('", \'', '').replace("', '", '').replace("', \"", '').replace(']', '')



0
投票
示例:

get_estimator_params(“SVC”)

预期结果:

{ “C”:1.0, “break_ties”:假, “缓存大小”:200, “class_weight”:空, “coef0”:0.0, “决策函数形状”:[ “ovo”、“ovr” ], “程度”:3, “伽玛”:[ “缩放”、“自动” ], “核心”: [ “线性”、“聚”、“rbf”、“sigmoid”、“预计算” ], “最大迭代器”:-1, “概率”:假, “随机状态”:空, “收缩”:正确, “公差”:0.001, “详细”:假 }

© www.soinside.com 2019 - 2024. All rights reserved.