我正在尝试用 Python 从头开始实现线性回归。
这是我尝试过的:
class LinearRegression:
    """Least-squares linear regression fitted with batch gradient descent.

    A column of ones is prepended to the feature matrix, so ``params[0]`` is
    the intercept and ``params[1:]`` are the per-feature weights. The
    parameters start from a standard-normal random draw.

    Gradient descent only converges when ``alpha`` is small relative to the
    scale of the features. With large, unscaled features the cost grows on
    every step until it overflows; ``gradientDescent`` detects this and stops
    with a warning instead of iterating on NaN values.
    """

    def __init__(
        self,
        features: np.ndarray,
        targets: np.ndarray,
    ) -> None:
        """Store the training data and draw random initial parameters.

        Args:
            features: 2-D array of shape (num_samples, num_feats).
            targets: Target values, one per sample. Flattened to 1-D.
        """
        features = np.asarray(features, dtype=np.float64)
        self.features = np.concatenate(
            (np.ones((features.shape[0], 1)), features), axis=1
        )
        # Flatten so an (n, 1) column of targets cannot broadcast against the
        # (n,) predictions into an (n, n) matrix.
        self.targets = np.asarray(targets, dtype=np.float64).reshape(-1)
        self.params = np.random.randn(features.shape[1] + 1)
        self.num_samples = features.shape[0]
        self.num_feats = features.shape[1]
        # Cost measured just before each parameter update.
        self.costs = []

    def hypothesis(self) -> np.ndarray:
        """Return the predictions for the stored features."""
        return np.dot(self.features, self.params)

    def cost_function(self) -> np.float64:
        """Return the halved mean squared error of the current parameters."""
        residual = self.hypothesis() - self.targets
        return (1 / (2 * self.num_samples)) * np.dot(residual, residual)

    def update(self, alpha: float) -> None:
        """Take one gradient step of size ``alpha`` on the parameters."""
        residual = self.hypothesis() - self.targets
        gradient = self.features.T @ residual
        self.params = self.params - (alpha / self.num_samples) * gradient

    def gradientDescent(self, alpha: float, threshold: float, max_iter: int) -> None:
        """Run gradient descent until convergence, divergence or ``max_iter``.

        Args:
            alpha: Learning rate.
            threshold: Stop once the absolute change in cost between two
                consecutive iterations falls below this value.
            max_iter: Maximum number of parameter updates to perform.

        Warns:
            RuntimeWarning: If the cost becomes infinite or NaN. The loop
                stops immediately; ``params`` is left in its diverged state.
        """
        import warnings

        iterations = 0
        # Overflow is reported once through the warning below rather than as
        # a stream of NumPy floating-point warnings.
        with np.errstate(over="ignore", invalid="ignore"):
            curr_cost = self.cost_function()
            while iterations < max_iter:
                iterations += 1
                self.costs.append(curr_cost)
                self.update(alpha)
                new_cost = self.cost_function()
                if not np.isfinite(new_cost):
                    warnings.warn(
                        "Gradient descent diverged: cost became non-finite "
                        f"after {iterations} iteration(s). Reduce alpha or "
                        "rescale the features.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    break
                if abs(new_cost - curr_cost) < threshold:
                    break
                # The post-update cost is the next iteration's starting cost.
                curr_cost = new_cost
我是这样使用这个类的:
# Reproduction case: unscaled features (values from 0 to 1000) combined with a
# learning rate of 0.1.
raw_features = np.linspace(0, 1000, 200, dtype=np.float64).reshape((20, 10))
raw_targets = np.linspace(0, 200, 20, dtype=np.float64)
regr = LinearRegression(features=raw_features, targets=raw_targets)
# Positional arguments are alpha, threshold, max_iter.
regr.gradientDescent(0.1, 1e-3, 1e+3)
regr.cost_function()
但是,我收到以下错误:
RuntimeWarning: overflow encountered in scalar power
return (1 / (2 * self.num_samples)) * (la.norm(self.hypothesis() - self.targets) ** 4)
RuntimeWarning: invalid value encountered in scalar subtract
if abs(new_cost - curr_cost) < threshold:
RuntimeWarning: overflow encountered in matmul
self.params = self.params - (alpha / self.num_samples) * (self.features.T @ (self.hypothesis() - self.targets))
有人能帮我弄清楚到底是哪里出了问题吗?
之所以溢出,是因为示例中的特征值太大(最大到 1000):学习率为 0.1 时每一步更新都会越过最优点,梯度下降发散,代价不断增大直到溢出。请先把数据缩小再试,例如:
# Same data as the failing example, with every feature and target divided by
# 1000 so the values lie in [0, 1] before running gradient descent.
scale = 1000
scaled_features = np.linspace(0, 1000, 200, dtype=np.float64).reshape((20, 10)) / scale
scaled_targets = np.linspace(0, 200, 20, dtype=np.float64) / scale
regr = LinearRegression(features=scaled_features, targets=scaled_targets)
# Positional arguments are alpha, threshold, max_iter.
regr.gradientDescent(0.1, 1e-3, 1e+3)
regr.cost_function()
它给我的输出为 0.00474225348416323。