PyTorch: parallelizing a linalg.solve() loop


Goal: I have a function that calls torch.linalg.solve(), and I want it to run as fast as possible.

Setup: I have an input_array (size 50x100x100) and a host_array (size 100x100x100). My function solver takes input_array[:,i,j] as input and outputs a vector of size 100, which is stored in host_array[:,i,j]. I am running a nested loop over all rows and columns of input_array to fill host_array.

Problem: This runs slowly, especially since in my real use case each call to the function takes about a second. I am currently running it with a nested loop, and I am wondering whether parallelizing my function would make it faster?

Example code:

import torch
from tqdm import tqdm

# Create an empty host_array and an input_array with random data
host_array = torch.zeros(100, 500, 500)
input_array = torch.randn(50, 500, 500)

# Create a dummy coefficient matrix A (50x100)
A = torch.randn(50, 100)

# Define your function to solve for input_array[:, i, j] and update host_array[:, i, j]
def solver(input_vector, A):

    # Solve the normal equations A.T @ A x = A.T @ b
    solution = torch.linalg.solve(A.T@A, A.T@input_vector)

    return solution

# Set up a progress bar over all (i, j) positions
total_iterations = int(host_array.shape[1]*host_array.shape[2])
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)

# Iterate through the input_array
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
        progress_bar.update(1)
Tags: python, optimization, pytorch, parallel-processing, linear-algebra
1 Answer

You can make use of the broadcasting capabilities of torch.linalg.solve() to get a significant speedup - see the # Proposed solution section, in particular the function solver_batch(), in my code below. I annotated the shapes that result from the necessary reshaping of the inputs (squeezing, unsqueezing, and permuting).
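
As a quick toy illustration of the broadcasting rule being relied on (the shapes and names below are made up purely for illustration): a single square matrix on the left is broadcast against a whole batch of right-hand sides.

import torch

# A single 4x4 system matrix broadcast over a 2x3 batch of right-hand sides
A_toy = torch.randn(4, 4)
b_toy = torch.randn(2, 3, 4, 1)   # batch dims (2, 3), each right-hand side is 4x1
x_toy = torch.linalg.solve(A_toy, b_toy)
print(x_toy.shape)  # torch.Size([2, 3, 4, 1]) - one solution per batch element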

from datetime import datetime
import torch

torch.manual_seed(42)  # Make result reproducible

B, C = 500, 500
M, N = 100, 50

input_array = torch.randn(N, B, C)
A = torch.randn(N, M)

# Proposed solution

t_start = datetime.now()
def solver_batch(inpt, a):
    at_a = a.T @ a  # MxM
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    at_input = a.T @ inpt  # BxCxMx1
    
    result = torch.linalg.solve(at_a, at_input)  # BxCxMx1
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC

proposed_result = solver_batch(input_array, A)
t_stop = datetime.now()
print(f"Proposed solution took {(t_stop - t_start).total_seconds():.2f}s")

# Previous solution

t_start = datetime.now()
host_array = torch.zeros(M, B, C)  # Will hold the result

def solver(input_vector, A):
    return torch.linalg.solve(A.T@A, A.T@input_vector)

for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
t_stop = datetime.now()
print(f"Previous solution took {(t_stop - t_start).total_seconds():.2f}s")

# Check results

left = (A.T @ A) @ host_array.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ host_array      == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))

left = (A.T @ A) @ proposed_result.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ proposed_result == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))

On my machine, I get:

>>> Proposed solution took 0.62s
>>> Previous solution took 21.74s
>>> (A.T @ A) @ host_array      == A.T @ input_array? True
>>> (A.T @ A) @ proposed_result == A.T @ input_array? True

Note that, while host_array and proposed_result both hold valid solutions, they do not necessarily hold the same solution (and in fact, for the given random seed, they don't). If I understand correctly, this is because the result of torch.linalg.solve() is unique if and only if its first argument (A.T @ A in our case) is invertible, which does not appear to be the case here. If you want to check that your approach and mine really do produce the same solution for an invertible matrix (up to numerical error), you can construct an MxM invertible matrix yourself following this approach and substitute it for A.T @ A for testing purposes.
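
For instance, such a check might look like the following minimal sketch. The invertible matrix is built here by adding a multiple of the identity to a random symmetric matrix, which is just one convenient construction chosen for this sketch (not necessarily the one from the linked approach):

import torch

torch.manual_seed(42)
B, C = 500, 500
M, N = 100, 50

input_array = torch.randn(N, B, C)
A = torch.randn(N, M)

# A.T @ A has rank at most N = 50, so it cannot be invertible for M = 100
print("rank of A.T @ A:", torch.linalg.matrix_rank(A.T @ A))

# Build an MxM matrix that is guaranteed to be invertible:
# a random symmetric matrix shifted by a multiple of the identity
# (symmetric positive definite, hence invertible)
R = torch.randn(M, M)
K = R @ R.T + M * torch.eye(M)

rhs = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)  # BxCxMx1

# Batched solve vs. looped solves at a few positions; with an invertible
# matrix, both should agree up to numerical error.
batched = torch.linalg.solve(K, rhs)  # BxCxMx1
for i, j in [(0, 0), (3, 7), (B - 1, C - 1)]:
    looped = torch.linalg.solve(K, rhs[i, j])  # Mx1
    print((i, j), torch.allclose(batched[i, j], looped, atol=1e-4))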

Also note that I had to be quite generous (atol=1e-3) when comparing the results with torch.allclose(), as quite a bit of numerical error seems to creep in.
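
If you want to see how large that error actually is before settling on a tolerance, a quick standalone check (my own sketch, repeating the batched computation from above) is to look at the largest elementwise deviation of the residuals directly:

import torch

torch.manual_seed(42)
B, C = 500, 500
M, N = 100, 50

input_array = torch.randn(N, B, C)
A = torch.randn(N, M)

at_a = A.T @ A                                          # MxM
rhs = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)  # BxCxMx1
solution = torch.linalg.solve(at_a, rhs)                # BxCxMx1

# Largest elementwise deviation between (A.T @ A) @ x and A.T @ b,
# i.e. the residual that atol has to absorb
max_abs_err = (at_a @ solution - rhs).abs().max().item()
print(f"max abs deviation: {max_abs_err:.2e}")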
