简而言之,我们得到一个形状为
y
的 4D 张量 ( B // s2, D2 // s1, s1, s2)
,其中 y[i,j,...]
表示形状为 (s1,s2) 的矩阵。这些是用于构造形状为(D2,B)的整体大矩阵的分块矩阵,总共有(B//s2) * (D2 //s1)
这样的分块矩阵。这里我们假设所有涉及的数字都是整数。我很清楚如何使用 for 循环来做到这一点:
# y shape ( B // s2, D2 // s1, s1, s2)
result = torch.zeros(D2, B)
for i in range(D2 // s1):
for j in range(B // s2):
result[i * s1: (i + 1) * s1, j * s2: (j + 1) * s2] = y[j,i, ...]
我知道作业可以并行完成。我们可以使用pytorch内置函数来消除两个for循环吗?
nn.Fold
和F.fold
就是为此目的而制作的。如果您查看文档,它会显示:
将滑动局部块数组组合成一个大的包含张量。
考虑一个包含滑动局部块的批量输入张量,例如,形状为的图像块,其中:(N,C×∏(kernel_size),L)
是批次维度,N
是块内值的数量(一个块有C×∏(kernel_size)
个空间位置,每个空间位置包含一个∏(kernel_size)
通道向量),C
是区块总数。L
这与
的输出形状完全相同。此操作通过对重叠值求和,将这些局部块组合成形状为Unfold
的大输出张量。(N,C,output_size[0],output_size[1],…)
在您的情况下,您的输入张量的形状为
(B//s2,D2//s1,s1,s2)
。要了解规格,您有 L = B//s2 * D2//s1
和内核 ∏(kernel_size) = s1 * s2
。由于函数期望 L
位于最后一个位置,因此您需要在展平之前进行一些排列:
y_ = y.permute(3,2,0,1).flatten(0,1).flatten(1)
现在
y_
已成形为(s1*s2, B//s2*D2//s1)
。最后你可以应用折叠:
F.fold(y_, output_size=(D2,B), kernel_size=(s1,s2), stride=(s1,s2))