Pytorch 逻辑回归模型没有学习,每个时期都给出相同的预测

问题描述 投票:0回答:0

我正在尝试使用 Logistic 回归使用 pytorch 执行 ML 任务,以解决具有 17 个数字属性作为输入的二元分类问题。

这是我的模型:

class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(17, 1)

    def forward(self, x):
        output = torch.sigmoid(self.linear(x))
        return output

纪元代码:

self.loss_func = torch.nn.BCELoss()
self.net = LogisticRegression()
self.opti = optim.Adam(self.net.parameters(), lr=0.01)
for epoch in range(local_epochs):
    for data, label in self.train_dl:
        data, label = data.to(self.dev), label.to(self.dev)
        self.opti.zero_grad()
        preds = self.net(data)
    
        preds = torch.round(preds)
        loss = self.loss_func(preds, label.to(torch.float32))
        loss.backward()
        self.opti.step()

计算精度的代码:

sum_accu = 0
num = 0
for data, label in self.test_dl:
    data, label = data.to(self.dev), label.to(self.dev)
    preds = self.net(data)
        print(preds)
    rounded_preds = torch.round(preds)
    sum_accu += (rounded_preds == label).float().mean()
    num += 1
accuracy = sum_accu/num

每个时代的准确度总是相同的,所以当通过在每个时代结束时打印“preds”进行调查时,我观察到这些值完全相同,一遍又一遍地打印。这是从我的终端复制的两个这样的实例。

tensor([[0.3958],
        [0.3945],
        [0.7899],
        [0.7307],
        [0.4597],
        [0.6161],
        [0.4561],
        [0.5829],
        [0.5005],
        [0.5272],
        [0.5041],
        [0.5438],
        [0.5004],
        [0.3623],
        [0.5250],
        [0.5050],
        [0.7511],
        [0.5194],
        [0.4359],
        [0.4961],
        [0.4741],
        [0.2358],
        [0.5108],
        [0.6026],
        [0.5581],
        [0.7312],
        [0.5047],
        [0.5456],
        [0.3654],
        [0.3890],
        [0.5150],
        [0.4835],
        [0.5635],
        [0.6935],
        [0.4463],
        [0.5423],
        [0.6504],
        [0.5503],
        [0.5756],
        [0.6759],
        [0.5986],
        [0.4919],
        [0.4506],
        [0.5991],
        [0.6428],
        [0.5743],
        [0.4644],
        [0.5459],
        [0.4973],
        [0.5363],
        [0.4168],
        [0.4297],
        [0.4899],
        [0.4357],
        [0.4491],
        [0.7026],
        [0.5691],
        [0.3235],
        [0.5053],
        [0.6034],
        [0.5739],
        [0.3975],
        [0.5293],
        [0.4870],
        [0.4976],
        [0.4651],
        [0.7204],
        [0.5242],
        [0.3649],
        [0.5264],
        [0.4867],
        [0.3726],
        [0.5398],
        [0.5339],
        [0.5334],
        [0.3884],
        [0.4933],
        [0.4489],
        [0.3119],
        [0.3826],
        [0.4244],
        [0.5612],
        [0.5480],
        [0.5575],
        [0.5411],
        [0.6343],
        [0.5174],
        [0.4965],
        [0.5172],
        [0.4460],
        [0.5153],
        [0.3980],
        [0.5763],
        [0.4840],
        [0.4682],
        [0.5400],
        [0.7034],
        [0.4811],
        [0.4481]])
tensor([[0.3958],
        [0.3945],
        [0.7899],
        [0.7307],
        [0.4597],
        [0.6161],
        [0.4561],
        [0.5829],
        [0.5005],
        [0.5272],
        [0.5041],
        [0.5438],
        [0.5004],
        [0.3623],
        [0.5250],
        [0.5050],
        [0.7511],
        [0.5194],
        [0.4359],
        [0.4961],
        [0.4741],
        [0.2358],
        [0.5108],
        [0.6026],
        [0.5581],
        [0.7312],
        [0.5047],
        [0.5456],
        [0.3654],
        [0.3890],
        [0.5150],
        [0.4835],
        [0.5635],
        [0.6935],
        [0.4463],
        [0.5423],
        [0.6504],
        [0.5503],
        [0.5756],
        [0.6759],
        [0.5986],
        [0.4919],
        [0.4506],
        [0.5991],
        [0.6428],
        [0.5743],
        [0.4644],
        [0.5459],
        [0.4973],
        [0.5363],
        [0.4168],
        [0.4297],
        [0.4899],
        [0.4357],
        [0.4491],
        [0.7026],
        [0.5691],
        [0.3235],
        [0.5053],
        [0.6034],
        [0.5739],
        [0.3975],
        [0.5293],
        [0.4870],
        [0.4976],
        [0.4651],
        [0.7204],
        [0.5242],
        [0.3649],
        [0.5264],
        [0.4867],
        [0.3726],
        [0.5398],
        [0.5339],
        [0.5334],
        [0.3884],
        [0.4933],
        [0.4489],
        [0.3119],
        [0.3826],
        [0.4244],
        [0.5612],
        [0.5480],
        [0.5575],
        [0.5411],
        [0.6343],
        [0.5174],
        [0.4965],
        [0.5172],
        [0.4460],
        [0.5153],
        [0.3980],
        [0.5763],
        [0.4840],
        [0.4682],
        [0.5400],
        [0.7034],
        [0.4811],
        [0.4481]])

我原以为随着模型学习和更新参数,每个时期的输出都会发生变化,但是每个时期相同的输出导致相同的准确度,这清楚地表明模型没有学习。尝试使用从 0.1 到 0.00001 的学习率。

pytorch logistic-regression
© www.soinside.com 2019 - 2024. All rights reserved.