说我有一张张量为张量的图像,其尺寸为(B x C x W x H),其中B是批量大小,C是图像中通道的数量,W和H是宽度和图像的高度。我希望使用transforms.Normalize()
函数相对于数据集跨C个图像通道的均值和标准偏差来标准化我的图像,这意味着我想要一个结果张量形式为1 xC。有没有简单的方法可以做到这一点?
我尝试了torch.view(C, -1).mean(1)
和torch.view(C, -1).std(1)
,但出现错误:
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
编辑
[研究了view()
在PyTorch中的工作原理之后,我知道了为什么我的方法行不通;但是,我仍然不知道如何获取每个通道的均值和标准差。
您只需要以正确的方式重新排列批处理张量:从[B, C, W, H]
到[B, C, W * H]
: