标准偏差值错误

问题描述 投票:0回答:2

您好,我正在尝试评估数据集MNIST的标准差和均值,但是我得到的标准差值有误。这是我的代码:

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F

loader = torch.utils.data.DataLoader(datasets.MNIST(
'../data', train=True, download=True, transform=transform1),
                     batch_size=32,
                     num_workers=0,
                     shuffle=False)

mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0) 
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(loader.dataset)
std /= len(loader.dataset)

print("The mean is ", mean)
print("The standard deviation is ", std)

我的问题是,我得到的平均值为0.1307,标准差的取值为0.3015,而不是0.3081。我想我的代码中有错误,但看不到哪里。

您能帮我吗?

非常感谢!

python python-3.x artificial-intelligence pytorch mnist
2个回答
0
投票

torch.std使用批次平均值作为计算的一部分,因此它与在整个数据集上使用torch.std不同,因为它将使用不同的平均值。我们可以使用以下well known expression作为方差以获得所需的结果

Var(X)= E [X ** 2]-E [X] ** 2

因此我们的估计变为

mean = 0.
mean_square = 0.
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1) 
    mean += images.mean(2).sum(0)
    mean_square += (images**2).mean(2).sum(0)

mean /= len(loader.dataset)
mean_square /= len(loader.dataset)
std = torch.sqrt(mean_square - mean**2)

print("The mean is ", mean)
print("The standard deviation is ", std)

0
投票

这里的微小差异是由于平均值和标准偏差在您的代码中以及通常在进行归一化时没有以相同的方式进行计算。

这里,您要做的是在每个图像的所有像素上计算每个批次的均值和标准差,然后取它们的平均值。您最终得到0.3015的值。

现在,如果要计算整个数据集的均值和标准差,则将不会使用相同的均值,并且最终会找到0.3081的值。

© www.soinside.com 2019 - 2024. All rights reserved.