sklearn SVM自定义内核引发ValueError:X.shape [0]应该等于X.shape [1]

问题描述 投票:0回答:1

我正在尝试实现一个自定义内核,正是指数Chi-Squared内核,作为参数传递给sklearn svm函数,但是当我运行它时会引发后续错误:ValueError:X.shape [0]应该等于X.shape [1]

我读到了numpy函数执行的广播操作,以加快计算速度,但我无法管理错误。

代码是:

import numpy as np
from sklearn import svm, datasets

# import the iris dataset (http://en.wikipedia.org/wiki/Iris_flower_data_set)
iris = datasets.load_iris()
train_features = iris.data[:, :2]  # Here we only use the first two features.
train_labels = iris.target


def my_kernel(x, y):
    gamma = 1
    return np.exp(-gamma * np.divide((x - y) ** 2, x + y))


classifier = svm.SVC(kernel=my_kernel)

classifier = classifier.fit(train_features, train_labels)

print "Train Accuracy : " + str(classifier.score(train_features, train_labels))

有帮助吗?

python numpy machine-learning classification svm
1个回答
0
投票

我相信已经为你实现了Chi-Squared内核(在from sklearn.metrics.pairwise import chi2_kernel中)。

像这样

from functools import partial

from sklearn import svm, datasets
from sklearn.metrics.pairwise import chi2_kernel

# import the iris dataset (http://en.wikipedia.org/wiki/Iris_flower_data_set)
iris = datasets.load_iris()
train_features = iris.data[:, :2]  # Here we only use the first two features.
train_labels = iris.target

my_chi2_kernel = partial(chi2_kernel, gamma=1)

classifier = svm.SVC(kernel=my_chi2_kernel)

classifier = classifier.fit(train_features, train_labels)

print("Train Accuracy : " + str(classifier.score(train_features, train_labels)))

====================

编辑:

事实证明,问题实际上是关于如何实现卡方内核。我对此的镜头是: -

def my_chi2_kernel(X):
    gamma = 1
    nom = np.power(X[:, np.newaxis] - X, 2)
    denom = X[:, np.newaxis] + X
    # NOTE: We need to fix some entries, since division by 0 is an issue here.
    #       So we take all the index of would be 0 denominator, and fix them.
    zero_denom_idx = denom == 0
    nom[zero_denom_idx] = 0
    denom[zero_denom_idx] = 1

    return np.exp(-gamma * np.sum(nom / denom, axis=len(X.shape)))

所以在本质上x - yx + y在OP的尝试中是错误的,因为它不是成对减法或加法。

奇怪的是,自定义版本似乎比sklearn的cython版本更快(至少对于小数据集?)

© www.soinside.com 2019 - 2024. All rights reserved.