从头开始支持SVM

问题描述 投票:0回答:1

从以下位置引用时,我从头开始构建了一个SVM模型:

https://www.youtube.com/watch?v=UX0f9BNBcsY&list=PLqnslRFeH2Upcrywf-u2etjdxxkL8nl7E&index=7

[当我使用sklearn乳腺癌数据集测试该模型时,测试准确度为0.5

相比之下,我从sklearn svm.SVC(kernel ='linear',C = 10)获得0.98的测试精度。

我认为这可能是因为数据不是线性可分离的,但是无论如何,sklearn模型的表现都很好。

我在这里做错了什么?

代码:

import numpy as np 

class SVM:
    def __init__(self, alpha = 0.001, C = 10, n_iters = 1000):
        self.alpha = alpha                                              # Learning Rate
        self.n_iters = n_iters     
        self.theta = None                                               # Weights
        self.bias = None
        self.C = C                                                      # C = 1/(lambda)

    def fit(self, X_train, y_train):
        # Prepare input y = 0 -> -1
        y_ = np.where(y_train <= 0, -1, 1)
        m, n = X_train.shape

        # init weights and bias
        self.theta = np.zeros(n)
        self.bias = 0

        # Gradient Descent and Updating weights and bias
        for i in range(self.n_iters):  
            for idx, x_i in enumerate(X_train):
                condition = y_[idx] * (x_i @ self.theta - self.bias) >= 1 
                if condition: 
                    self.theta -= self.alpha * (self.theta)
                else:
                    self.theta -= self.alpha * (self.theta - self.C * y_[idx] * x_i)
                    self.bias -= self.alpha * self.C * y_[idx]

    def predict(self, X_test):
        # Note: During Prediction, evaluation threshold is z>0 , z<0
        output =  X_test @ self.theta - self.bias
        return np.sign(output)

    def score(self, X_test, y_test):
        prediction = self.predict(X_test)
        return np.sum(prediction == y_test)/len(y_test)


更新:我已经应用了特征缩放,并且目前正试图通过强行强制学习率来逃避局部最优

machine-learning svm
1个回答
0
投票

Sklearn SVC类使用LibSVM实现,LibSVM实现顺序最小优化算法。

SMO算法比以前的算法,例如Chunking方法和Osuna’s>]算法要快得多。

此实现实际上并不能解决SVM问题,通过使用梯度下降

铰链损耗,svm二次问题没有局部最小值,您只需将SVM二次问题内的某些位置移动即可。

您可能只是想处理二进制分类

而可能遗漏的另一件事,可能您正在使用的数据具有多类

Finally

,要解决SVM问题,您需要解决quadratic问题,对于二次问题,应使用数值算法来解决,或者使用SMO算法来解析地解决问题,有效地。
© www.soinside.com 2019 - 2024. All rights reserved.