背景:
我正在开发一个程序,该程序首先沿“列”维度以不同距离移动张量的不同通道,然后沿“通道”维度执行求和以将不同维度合并为一个。具体来说,给定大小为 (B,C,H,W) 和步长为 S 的张量 x,其中 B、C、H、W 分别表示批量大小、通道数、高度和宽度,即第 i 个通道x 平移距离 (i-1)*S,然后将 C 个通道求和为 1。
这是一个一维玩具示例。 假设我有一个 3 通道张量 x as
x = torch.tensor(
[[1,1,1],
[2,2,2],
[3,3,3]]
)
现在我将步长设置为1,然后对张量执行移位为
x_shifted = torch.tensor(
[[1,1,1,0,0],
[0,2,2,2,0],
[0,0,3,3,3]]
)
这里,第一个通道移动了距离 0,第二个通道移动了距离 1,第三个通道移动了距离 2。 最后,将所有三个通道相加并合并为一个通道
y = torch.tensor(
[[1,3,6,5,3]]
)
问题:
我已经实施了原始流程。以下代码中的 2D 图像张量:
import torch
import torch.nn.functional as F
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
alpha[..., (i*S):(i*S+W)] += 1
def A(x, mask):
z = x * mask
y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
for i in range(C):
y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
return y
def A_pinv(y, mask):
z = y / alpha.to(y.device)
x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
return x
#############################################
# Method 2
#############################################
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
kernel[:, C-i-1, :, i*S] = 1
def A_fast(x, mask):
return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))
def A_pinv_fast(y, mask):
return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
#############################################
# Test 1
#############################################
start_time = time()
MAE = 0
for i in range(T):
x = torch.rand(B, C, H, W, device=device)
mask = torch.rand(1, 1, H, W, device=device)
mask[mask == 0] = 1e-12
y = A(x, mask)
x_init = A_pinv(y, mask)
y_init = A(x_init, mask)
MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
#############################################
# Test 2
#############################################
start_time = time()
MAE = 0
for i in range(T):
x = torch.rand(B, C, H, W, device=device)
mask = torch.rand(1, 1, H, W, device=device)
mask[mask == 0] = 1e-12
y = A_fast(x, mask)
x_init = A_pinv_fast(y, mask)
y_init = A_fast(x_init, mask)
MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
这里,
Method 1
使用for
循环实现该过程,而我相信Method 2
通过使用2D卷积运算等效地实现该过程。
更具体地说,函数
A
和A_pinv
分别实现了转发压缩过程及其“伪逆”。他们在 Method 2
中的“快速”版本预计通过并行实现会更快。
但是,当我运行代码时,我发现
Method 1
仍然比速度领先的Method 2
快很多。我想问一下:
我们能否有效加速
Method 1
?更具体地说,我想知道我们是否可以并行化 for
循环,以使“Shift+Summation”过程更快?
加速“方法 1”中“Shift+Summation”过程的一种方法是并行化 for 循环。在您的情况下,您可以通过利用“torch.nn.parallel.data_parallel”函数,使用PyTorch的内置并行处理功能来实现此目的。此函数可跨多个 GPU 或 CPU 核心并行计算。
import torch
import torch.nn.functional as F
from torch.nn.parallel import data_parallel
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1 Edited***
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
alpha[..., (i*S):(i*S+W)] += 1
def A(x, mask): # Edited From Here.
z = x * mask
y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
def process_channel(i):
y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
data_parallel(process_channel, range(C))
return y
def A_pinv(y, mask):
z = y / alpha.to(y.device)
x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
return x # End here.
通过使用 data_parallel 函数,“A”函数中的 for 循环将在可用设备上自动并行化,从而加快执行速度。如果您有多个 GPU,并行化的好处将更加明显。 (你用的是哪个GPU?)
您可以在官方文档中了解有关此主题的更多信息:https://pytorch.org/docs/stable/nn.function.html(这是页面上的最后一项)。