PyTorch MSE loss differs from the direct computation by a factor of 2


Why does torch.nn.functional.mse_loss(x1,x2) give a different result from computing the MSE directly?

Test code to reproduce the issue:

import torch
import numpy as np

# Think of x1 as predicted 2D coordinates and x2 as the ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557

mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314

mse_manual = 0
for i in range(len(x1)) :
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314 

As we can see, torch's mse_loss returns 0.1557, which differs from the manually computed MSE of 0.3314.

In fact, the direct result is the mse_loss result multiplied by the dimensionality of the points (here 2).

What is going on here?

pytorch torch loss-function loss mse
Answer

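The difference is that torch.nn.functional.mse_loss(x1,x2) does not sum over the coordinates when computing the squared error: with its default reduction='mean' it averages the squared differences over all N*D elements, i.e. sum((x1 - x2)**2) / (N*D). torch.nn.functional.pairwise_distance and np.linalg.norm, however, first sum the squared differences over the coordinates of each point, so averaging over the N points gives sum((x1 - x2)**2) / N, which is D times larger (here D = 2), up to the small eps that pairwise_distance adds by default. You can reproduce the value computed by mse_loss as follows: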

import torch
import numpy as np

x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557

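mse_manual = 0
x3 = torch.zeros(10,2)
for i in range(len(x1)):
    # squared error of each coordinate separately, divided by the number of points
    x3[i,:1] += (torch.nn.functional.pairwise_distance(x1[i,:1],x2[i,:1],eps=0.0)**2)/len(x1)
    x3[i,1:] += (torch.nn.functional.pairwise_distance(x1[i,1:],x2[i,1:],eps=0.0)**2)/len(x1)
    mse_manual += x3[i]
# mse_manual holds one value per coordinate; taking the mean divides by D as well
print(mse_manual.mean()) # 0.1557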

mse_manual = 0
for i in range(len(x1)):
    # element-wise squared error: no summation over the coordinates
    mse_manual += np.square((x1[i]-x2[i]).numpy()) / len(x1)
print(mse_manual.mean()) # 0.1557
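Going the other way, you can also recover the per-point value from mse_loss itself. The following is a minimal self-contained check, assuming a reasonably recent PyTorch. It relies only on the standard `reduction` argument of `mse_loss` and the `eps` argument of `pairwise_distance`. The seed, the sizes N and D, and the names `mse_per_point`, `from_sum` and `from_mean` are arbitrary choices for illustration. It compares the per-point MSE (squared Euclidean distance averaged over the points) with `mse_loss(..., reduction="sum")` divided by N and with the default `mse_loss` multiplied by D.

```python
import torch
import torch.nn.functional as F

# Fixed seed so the run is reproducible (the values themselves are arbitrary)
torch.manual_seed(0)

N, D = 10, 2                      # N points, each with D coordinates
x1 = torch.rand(N, D)             # e.g. predicted coordinates
x2 = torch.rand(N, D)             # e.g. ground-truth coordinates

# Default reduction='mean': sum of squared errors divided by N * D
mse_torch = F.mse_loss(x1, x2)

# Per-point MSE: squared Euclidean distance of each point, averaged over N only
# (eps=0.0 so pairwise_distance computes exactly ||x1 - x2|| per row)
mse_per_point = F.pairwise_distance(x1, x2, eps=0.0).square().mean()

# Two ways to obtain the per-point value from mse_loss
from_sum = F.mse_loss(x1, x2, reduction="sum") / N   # divide by N only
from_mean = mse_torch * D                              # undo the extra 1/D

print(mse_torch.item(), mse_per_point.item())
print(torch.allclose(mse_per_point, from_sum), torch.allclose(mse_per_point, from_mean))
```

Which variant to use depends on the definition you want: if the loss should be the mean squared Euclidean distance per point, divide the summed loss by N; if it should be the mean over every individual coordinate, the default mse_loss already computes that.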
© www.soinside.com 2019 - 2024. All rights reserved.