如何在 PyTorch 中计算三个张量的批量协方差?

问题描述 投票:0回答:1

假设我们有 3 个大小为

(B, C, H, W)
的张量,其中
B
是批量大小,
C
是通道维度。我期望计算这 3 个张量沿通道维度的协方差。

我尝试过以下代码:

x1_mean = x1.mean(dim=1).unsqueeze(dim=1)
x2_mean = x2.mean(dim=1).unsqueeze(dim=1)
x3_mean = x3.mean(dim=1).unsqueeze(dim=1)
out = torch.matmul(torch.matmul(x1 - x1_mean, x2 - x2_mean), x3 - x3_mean)

只是想知道我的代码是否有意义。还有另一种方法来计算协方差吗? 任何帮助将不胜感激。

pytorch covariance
1个回答
0
投票

您计算沿通道维度的三个张量的协方差的方法是一个好的开始,但似乎对于如何计算协方差可能存在误解,特别是对于多个张量。您正在计算沿通道维度的平均值,然后将其解压缩以匹配原始张量的形状。这部分对于张量的均值居中是正确的。但是,在代码中使用

torch.matmul
无法正确计算协方差。协方差通常涉及两组变量(而不是三组)之间的成对计算,并且计算方式不同。

综上所述,假设您想在三个张量之间进行某种形式的多元协方差,以下代码可以为您完成工作:

# step 1: reshape tensors
B, C, H, W = x1.shape
x_combined = torch.cat([x1, x2, x3], dim=1)  # Shape: (B, C*3, H, W)
x_combined = x_combined.reshape(B, C*3, H*W) # Shape: (B, C*3, H*W)

# step 2: mean centering
mean_centered = x_combined - x_combined.mean(dim=2, keepdim=True)

# step 3: covariance matrix calculation
cov_matrix = torch.matmul(mean_centered, mean_centered.transpose(1, 2)) / (H*W - 1)

此代码片段将生成一批协方差矩阵,每个大小为 (C3, C3),表示每个空间位置处所有三个张量的每对通道之间的协方差。每个矩阵都是批次中一个示例的多元协方差矩阵。

© www.soinside.com 2019 - 2024. All rights reserved.