Goal: I have a function that calls torch.linalg.solve(), and I want it to run as fast as possible.
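Setup: I have an input_array (size 50x100x100) and a host_array (size 100x100x100). My function solver takes input_array[:,i,j] as input and outputs a vector of size 100, which is stored in host_array[:,i,j]. I fill host_array by running a nested loop over all rows and columns of input_array.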
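Problem: This runs slowly, especially since in my real use case each function call takes about a second. I am currently running it with nested loops, and I would like to know whether parallelizing my function would make it faster.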
Example code:
import torch
from tqdm import tqdm
# Create host_array and input_array with random data
host_array = torch.zeros(100, 500, 500)
input_array = torch.randn(50, 500, 500)
# Create a dummy coefficient matrix A (50x100)
A = torch.randn(50, 100)
# Define your function to solve for input_array[:, i, j] and update host_array[:, i, j]
def solver(input_vector, A):
    # Solve the linear system of equations
    solution = torch.linalg.solve(A.T@A, A.T@input_vector)
    return solution
# Total number of iterations, used by the progress bar
total_iterations = int(host_array.shape[1]*host_array.shape[2])
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)
# Iterate through the input_array
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
        progress_bar.update(1)
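You can use the broadcasting support of torch.linalg.solve() to get a significant speedup; see the # Proposed solution section, in particular the function solver_batch(), in my code below. I annotated the shapes that result from the necessary reshaping of the inputs (squeezing, unsqueezing, and permuting).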
from datetime import datetime
import torch
torch.manual_seed(42) # Make result reproducible
B, C = 500, 500
M, N = 100, 50
input_array = torch.randn(N, B, C)
A = torch.randn(N, M)
# Proposed solution
t_start = datetime.now()
def solver_batch(inpt, a):
    at_a = a.T @ a  # MxM
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    at_input = a.T @ inpt  # BxCxMx1
    result = torch.linalg.solve(at_a, at_input)  # BxCxMx1
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
proposed_result = solver_batch(input_array, A)
t_stop = datetime.now()
print(f"Proposed solution took {(t_stop - t_start).total_seconds():.2f}s")
# Previous solution
t_start = datetime.now()
host_array = torch.zeros(M, B, C) # Will hold the result
def solver(input_vector, A):
    return torch.linalg.solve(A.T@A, A.T@input_vector)
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:,i,j] = solver(input_array[:,i,j], A)
t_stop = datetime.now()
print(f"Previous solution took {(t_stop - t_start).total_seconds():.2f}s")
# Check results
left = (A.T @ A) @ host_array.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ host_array == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))
left = (A.T @ A) @ proposed_result.permute(1, 2, 0).unsqueeze(-1)
right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
print("(A.T @ A) @ proposed_result == A.T @ input_array?",
      torch.allclose(left, right, atol=1e-3))
On my machine, I get:
>>> Proposed solution took 0.62s
>>> Previous solution took 21.74s
>>> (A.T @ A) @ host_array == A.T @ input_array? True
>>> (A.T @ A) @ proposed_result == A.T @ input_array? True
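For intuition about the broadcasting that solver_batch() relies on, here is a minimal sketch on a deliberately small, well-posed system. The sizes and the names mat, rhs, batched, and looped are hypothetical and chosen only for illustration; the matrix is built to be symmetric positive definite so that every system has exactly one solution. It is not part of the benchmark above.

```python
import torch

torch.manual_seed(0)

B, C, M = 2, 3, 4  # small hypothetical sizes, for illustration only

# Build an invertible MxM matrix: Q @ Q.T is positive semi-definite,
# and adding M * I makes it positive definite (assumption of this sketch).
Q = torch.randn(M, M)
mat = Q @ Q.T + M * torch.eye(M)

rhs = torch.randn(B, C, M, 1)  # a BxC batch of right-hand-side column vectors

# One call: the single MxM matrix is broadcast over the leading BxC batch dims
batched = torch.linalg.solve(mat, rhs)  # BxCxMx1

# Reference: solve every system separately in a nested loop
looped = torch.empty(B, C, M)
for i in range(B):
    for j in range(C):
        looped[i, j] = torch.linalg.solve(mat, rhs[i, j, :, 0])

print(batched.shape)
print(torch.allclose(batched.squeeze(-1), looped, atol=1e-5))
```

Because the matrix here is invertible, the batched call and the loop should agree up to round-off, which is the behaviour you would expect from the benchmark as well if A.T @ A were invertible.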
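Note that while host_array and proposed_result both hold valid solutions, they do not necessarily hold the same solution (in fact, for the given random seed, they are not identical). If I understand correctly, this is because the result of torch.linalg.solve() is unique if and only if its first argument (A.T @ A in our case) is invertible, and here it is not: A is 50x100, so the 100x100 matrix A.T @ A has rank at most 50. Also note that I had to be quite generous (atol=1e-3) when comparing the results with torch.allclose(), since quite a bit of numerical error seems to accumulate.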