给定几个分块矩阵,得到整体大矩阵

问题描述 投票:0回答:1

简而言之,我们得到一个形状为

y
的 4D 张量
( B // s2, D2 // s1, s1, s2)
,其中
y[i,j,...]
表示形状为 (s1,s2) 的矩阵。这些是用于构造形状为(D2,B)的整体大矩阵的分块矩阵,总共有
(B//s2) * (D2 //s1)
这样的分块矩阵。这里我们假设所有涉及的数字都是整数。我很清楚如何使用 for 循环来做到这一点:

# y shape ( B // s2, D2 // s1, s1, s2)
result = torch.zeros(D2, B)
for i in range(D2 // s1):
    for j in range(B // s2):
         result[i * s1: (i + 1) * s1, j * s2: (j + 1) * s2] = y[j,i, ...]

我知道作业可以并行完成。我们可以使用pytorch内置函数来消除两个for循环吗?

python matrix pytorch reshape tensor
1个回答
0
投票

这称为折叠操作,

nn.Fold
F.fold
就是为此目的而制作的。如果您查看文档,它会显示:

将滑动局部块数组组合成一个大的包含张量。
考虑一个包含滑动局部块的批量输入张量,例如,形状为

(N,C×∏(kernel_size),L)
的图像块,其中:

  • N
    是批次维度,
  • C×∏(kernel_size)
    是块内值的数量(一个块有
    ∏(kernel_size)
    个空间位置,每个空间位置包含一个
    C
    通道向量),
  • L
    是区块总数。

这与

Unfold
的输出形状完全相同。此操作通过对重叠值求和,将这些局部块组合成形状为
(N,C,output_size[0],output_size[1],…)
的大输出张量。

在您的情况下,您的输入张量的形状为

(B//s2,D2//s1,s1,s2)
。要了解规格,您有
L = B//s2 * D2//s1
和内核
∏(kernel_size) = s1 * s2
。由于函数期望
L
位于最后一个位置,因此您需要在展平之前进行一些排列:

y_ = y.permute(3,2,0,1).flatten(0,1).flatten(1)

现在

y_
已成形为
(s1*s2, B//s2*D2//s1)
。最后你可以应用折叠:

F.fold(y_, output_size=(D2,B), kernel_size=(s1,s2), stride=(s1,s2))
© www.soinside.com 2019 - 2024. All rights reserved.