Pytorch:每轮接受相同的测试评估

问题描述 投票:0回答:0

我的联邦学习设置有问题。我有一些权重,我想使用 PyTorch 在每一轮中评估(测试)。现在我对模型的损失和准确性感兴趣。我使用的模型如下:

class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

以及产生结果的方法:

def load_testset():
    """Load CIFAR-10 (test set)."""
    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    testset = CIFAR10("./data", train=False, download=True, transform=trf)
    return DataLoader(testset), testset

def get_AccuracyAndLoss(weights):
    # Load the existing weights list
    weights_list = weights

    for i, weights in enumerate(weights_list):
        layer_name = 'layer_' + str(i)
        setattr(Net(), layer_name, nn.Parameter(torch.from_numpy(weights)))

    # Load model and data (simple CNN, CIFAR-10)
    net = Net().to(DEVICE)
    testloader, test_set = load_testset()

    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)

            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Test_loss: %.3f, accuracy: %.2f' % (loss/len(testloader), correct / total))

问题是我每一轮都得到相同的结果,大约是 Test_loss:2.306,准确度:0.10,我不知道模型是否不正确或权重插入的方式。我对深度学习还很陌生,因此不胜感激简单的答案。

大部分代码基于此处的 pytorch 深度学习教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

我正在使用 python 3.10

python deep-learning pytorch neural-network evaluation
© www.soinside.com 2019 - 2024. All rights reserved.