使用梯度保存来变换 PyTorch 张量的快速方法

问题描述 投票:0回答:1

问题

我需要一种快速简单的方法来手动将尺寸为 (D, M, M) 的 PyTorch 张量转换为尺寸为 (D*4, M//2, M//2) 的张量,而无需进行卷积。我想使用类似池化的方法,但采用展平和串联操作,其中内核大小始终为 2,步幅也为 2,以将采样降低到一半。保持梯度至关重要。

输入示例:

将 (3, 4, 4) 变换为 (12, 2, 2):

[[[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11], [12, 13, 14, 15]],

[[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],

[[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]]

所需输出:

[[[ 0,  1,  4,  5, 16, 17, 20, 21, 32, 33, 36, 37],

[ 2,  3,  6,  7, 18, 19, 22, 23, 34, 35, 38, 39]],

[[ 8,  9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45],

[10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47]]]

测试代码:

# Generate the input tensor
input_tensor = torch.arange(48).reshape(3, 4, 4)

# Get Shape
n, m, _ = input_tensor.shape

# DO CODE operation

#check output output_tensor[:,0,0] == [ 0,  1,  4,  5, 16, 17, 20, 21, 32, 33, 36, 37]
.....

我尝试创建一个中间步骤以获得所需的输出:

patches = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2).reshape(n, m//2,m//2, 4)

输出:

 output: tensor([[[[ 0,  1,  4,  5], [ 2,  3,  6,  7]],

     [[ 8,  9, 12, 13],
      [10, 11, 14, 15]]],


    [[[16, 17, 20, 21],
      [18, 19, 22, 23]],

     [[24, 25, 28, 29],
      [26, 27, 30, 31]]],


    [[[32, 33, 36, 37],
      [34, 35, 38, 39]],

     [[40, 41, 44, 45],
      [42, 43, 46, 47]]]])

但我仍然需要将这个补丁转换为 12,2,2 的向量并保持正确的顺序。

编辑 (D2, M//2, M//2) > (D4, M//2, M//2)

pytorch torch
1个回答
0
投票

你几乎已经做到了。在获得形状为 (n, m//2, m//2, 4) 的补丁后,您必须展平最后一个维度并将张量排列为正确的顺序,

torch.permute
就是这里的方法。这是完整的代码:

import torch

# Generate the input tensor
input_tensor = torch.arange(48).reshape(3, 4, 4)

# Get Shape
n, m, _ = input_tensor.shape

# Create patches
patches = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2).reshape(n, m//2, m//2, 4)

# Flatten the last dimension and permute the tensor to the correct order
output_tensor = patches.permute(1,2,0,3).reshape(m//2, m//2, n*4)

print(output_tensor)
© www.soinside.com 2019 - 2024. All rights reserved.