假设我们有 3 个大小为
(B, C, H, W)
的张量,其中 B
是批量大小,C
是通道维度。我期望计算这 3 个张量沿通道维度的协方差。
我尝试过以下代码:
x1_mean = x1.mean(dim=1).unsqueeze(dim=1)
x2_mean = x2.mean(dim=1).unsqueeze(dim=1)
x3_mean = x3.mean(dim=1).unsqueeze(dim=1)
out = torch.matmul(torch.matmul(x1 - x1_mean, x2 - x2_mean), x3 - x3_mean)
只是想知道我的代码是否有意义。还有另一种方法来计算协方差吗? 任何帮助将不胜感激。
您计算沿通道维度的三个张量的协方差的方法是一个好的开始,但似乎对于如何计算协方差可能存在误解,特别是对于多个张量。您正在计算沿通道维度的平均值,然后将其解压缩以匹配原始张量的形状。这部分对于张量的均值居中是正确的。但是,在代码中使用
torch.matmul
无法正确计算协方差。协方差通常涉及两组变量(而不是三组)之间的成对计算,并且计算方式不同。
综上所述,假设您想在三个张量之间进行某种形式的多元协方差,以下代码可以为您完成工作:
# step 1: reshape tensors
B, C, H, W = x1.shape
x_combined = torch.cat([x1, x2, x3], dim=1) # Shape: (B, C*3, H, W)
x_combined = x_combined.reshape(B, C*3, H*W) # Shape: (B, C*3, H*W)
# step 2: mean centering
mean_centered = x_combined - x_combined.mean(dim=2, keepdim=True)
# step 3: covariance matrix calculation
cov_matrix = torch.matmul(mean_centered, mean_centered.transpose(1, 2)) / (H*W - 1)
此代码片段将生成一批协方差矩阵,每个大小为 (C3, C3),表示每个空间位置处所有三个张量的每对通道之间的协方差。每个矩阵都是批次中一个示例的多元协方差矩阵。