神经网络输入缩放

问题描述 投票:0回答:1

我在CIFAR-10数据集上训练了一个简单的完全连接的网络:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(3*32*32, 300, bias=False)
        self.fc2 = nn.Linear(300, 10, bias=False)

    def forward(self, x):
        x = x.reshape(250, -1)
        self.x2 = F.relu(self.fc1(x))
        x = self.fc2(self.x2)
        return x


def train():
    # The output of torchvision datasets are PILImage images of range [0, 1].
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=250, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=4)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001)

    for epoch in range(20):
        correct = 0
        total = 0
        for data in trainloader:
            inputs, labels = data
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        acc = 100. * correct / total

在指定的参数下,经过20个周期后,该网络的测试精度达到〜50%。请注意,我没有对输入进行任何白化(每个通道均无减法)

接下来,通过将outputs = net(inputs)替换为outputs = net(inputs*255),将模型输入放大255。更改后,网络不再收敛。我看了看梯度,经过几次迭代,它们似乎爆炸性地增长,导致所有模型输出均为零。我想了解为什么会这样。

另外,我尝试将学习率降低255。这很有帮助,但网络的准确度仅为〜43%。再次,我不明白为什么这会有所帮助,更重要的是为什么与原始设置相比,准确性仍然会下降。

编辑:忘记提及我在此网络中不使用偏见。

tensorflow neural-network pytorch backpropagation
1个回答
0
投票

CIFAR-10比MNIST数据集要重要得多,因此,完全连接的神经网络不具备准确预测这些图像所需的表示能力。 CNN是除MNIST之外的任何图像分类任务的方法。不幸的是,〜50%的准确性是使用完全连接的神经网络所能获得的最大精度。

[Here's有关CIFAR-10上不同神经网络类型的性能的更多信息。

© www.soinside.com 2019 - 2024. All rights reserved.