为什么我的简单 MATLAB 线性回归梯度下降不起作用

问题描述 投票:0回答:1

我开始学习线性回归。我想自己实现梯度下降。我写了下面的代码。

%% Linear regression

close all;

dataset =load('accidents');
data = dataset.hwydata;
x = data(:,14);
y  =data(:,4);
%% Gradient descent
% We want to minimize a cost function and GD achieves that iteratively.
% J(w,b) =(y-y_est)^2

w = 0;
b = 0;
alpha =.000001; % I tried various alphas like 0.01, .1 etc. (Not working)
for i =1: 100
    y_est = w*x + b;
    J = mean((y-y_est).^2)
    temp_w = w + alpha*(mean(x.*(y-w*x-b)));
    temp_b = b + alpha*(mean(y-w*x-b));
    w =temp_w
    b =temp_b
end

由于某种原因,它不起作用。看来我的算法没有收敛。

我预计该算法能够很好地收敛,因为均方误差成本函数是凸的。

matlab linear-regression gradient-descent
1个回答
0
投票

简短的回答:您需要沿着梯度方向进行某种线搜索。

正如我在查询中提到的,较大的 x 和 y 值可能会导致算法发散,这就是您的程序中发生的情况,x 值高达大约。 3e7 和偏导数 w.r.t. 的初始值w 约为 1e10。

对于非常简单的线搜索:将当前评价函数(下面代码中的

Jcurrent
)与使用
temp_w
temp_b
和当前 alpha(下面代码中的
Jnew
)计算的函数进行比较。

如果

Jnew >= Jcurrent
按某个因子减少 alpha,并使用新的 alpha 和当前梯度重新计算
temp_w
temp_b
。使用更新后的
Jnew
temp_w
重新计算
temp_b
。重复直到
Jnew < Jcurrent

在此行搜索之后,您至少有两个选择:1)将 alpha 重置为其起始值(代码中的

alphaOrig
)或 2)保留当前 alpha。

请注意,此线搜索远非最佳。它仅搜索评价函数的减少并接受它,无论减少有多小。这会导致收敛缓慢。如果您希望我建议更好的线性搜索方法,请告诉我。

江淮汽车

%% Linear regression

clear; 
% Delete all figures
figureList = findobj('type', 'figure');
if ~isempty(figureList)
    delete(figureList);
end
    alphaOrig = 1e-4;
    dataset =load('accidents');
    data = dataset.hwydata;
    x = data(:,14);
    y  =data(:,4);
% ..Check the ranges of x and y    
    fprintf(1, 'max(x) = %g\tmax(y) = %g\n', max(abs(x)), max(abs(y)) );
%% Gradient descent
% We want to minimize a cost function and GD achieves that iteratively.
% J(w,b) =(y-y_est)^2

    w = 0;
    b = 0;
    alpha = alphaOrig;
% ..Initial value of merit function
    Jcurrent = mean((y - w*x - b).^2);
    for i =1: 100
        y_est = w*x + b;
        J = mean((y-y_est).^2);
    % ..Components of the gradient
        dJdw = mean(x.*(y-w*x-b));
        dJdb = mean(y-w*x-b);
        fprintf(1, '%d: J(%g,%g) = %.8g\t', i, w,b,Jcurrent);
        fprintf(1, 'dJ/dw = %g; dJ/db = %g\n', dJdw, dJdb);
    % ..Crude line search
        while true % Loop not in the original at stackoverflow
            temp_w = w + alpha*dJdw;
            temp_b = b + alpha*dJdb;
            Jnew = mean((y - temp_w*x - temp_b).^2);
            fprintf(1, '\talpha = %g;\tJ(%g,%g) = %.8g; \n', ...
                alpha, temp_w, temp_b, Jnew);
            if Jnew < Jcurrent
                break;
            end
            alpha = 0.1*alpha; % <== reduce alpha
            if alpha == 0
                error("Sorry. It's not going well");
            end
        end
        Jcurrent = Jnew;
%         alpha = alphaOrig; % This resets alpha
        w =temp_w;
        b =temp_b;
    end
    fmt = "%10s: w = %g; b= %g\n";
%     fprintf(1, fmt, "Noise-free", wa, ba);
    fprintf(1, fmt, "Estimate", w, b);
    fig = figure(1);clf
    plot(x,y, 'linestyle', 'none','marker', 'o');
    hold on
% ..Show the least-squares fit
    t = [0.9*min(x),1.2*max(x)];
    plot(t,w*t+b, 'linestyle', '--', 'color', 'black');
© www.soinside.com 2019 - 2024. All rights reserved.