Goal: I have a function, called in a loop, that takes a 1D tensor and a 2D tensor as input. Inside this function I use torch.linalg.solve(). I want to parallelize the loop to optimize the runtime.
Setup: I have 3 main tensors:

- input_tensor: size 50x100x100
- host_tensor: size 100x100x100
- A: size 50x100 (the design matrix)

input_tensor holds 100x100 input_vectors, all of length 50. The vectors also contain varying numbers of NaNs, which I mask out, so each masked input_vector has length ≤ 50. Note that the design matrix A gets masked accordingly and then has size (masked_length x 100). Since each input_vector and A pair has a different masked length, the function has to run pointwise.
Question: Is there a way to make the code below faster? And how do I deal with the design matrix A and the input_vector having a different size on every iteration?

Important note: The NaNs cannot simply be replaced with 0s, since that would corrupt the linear solve. For context, I asked a question about a similar process here.

Code:
import torch
from tqdm import tqdm
import numpy as np
from datetime import datetime

# Create "device" so we can migrate the tensors to GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Set the seed for reproducibility
torch.manual_seed(42)

# Set shapes to generate tensors
B, C = 500, 500
M, N = 100, 50

# Generate tensors
input_tensor = torch.randn(N, B, C)
host_tensor = torch.randn(M, B, C)
A = torch.randn(N, M)

# --- Here we input random NaNs in the input_tensor to simulate missing data --- #
# Define the probability of inserting NaN at each element
probability = 0.2  # You can adjust this as needed
# Generate random indices based on the probability
shape = input_tensor.shape
random_indices = torch.rand(shape) < probability
# Replace the selected indices with NaN values
input_tensor[random_indices] = float('nan')

# --- Migrate matrices to GPU --- #
A = A.to(device)
input_tensor = input_tensor.to(device)
host_tensor = host_tensor.to(device)

t_start = datetime.now()

# --- Function that creates a vector of size M from input_vector (size N) and A --- #
def solver(input_vector, A):
    # We create a mask to reduce the row size of A: rows where input_vector is NaN are not considered in the solver
    mask = ~torch.isnan(input_vector)
    # Mask the vector
    input_vector_masked = input_vector[mask]
    # Mask the array
    A_masked = A[mask]
    A_trans = A_masked.T
    # Solve the normal equations: (A.T A) x = A.T vec_obs
    return torch.linalg.solve(A_trans @ A_masked, A_trans @ input_vector_masked)

# --- Iterate through each vector of the input_tensor --- #
# Define the total number of iterations
total_iterations = B * C
# Create a tqdm progress bar
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)
# Iterate through every cell of input_array
for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        host_tensor[:, i, j] = solver(input_tensor[:, i, j], A)
        progress_bar.update(1)  # Update the progress bar

t_stop = datetime.now()
print(f"Inversion took {(t_stop - t_start).total_seconds():.2f}s")
I found the answers given here so far a bit unsatisfying. But let's take it step by step.
Zeroing NaNs == dropping NaNs

First of all, you can replace the NaNs with zeros. Take the following example: assume you have a vector v and a matrix A, given as
v = [v1 v2 v3] # N elements
A = [[a11 a12 a13] # NxM elements
     [a21 a22 a23]
     [a31 a32 a33]]
Now assume that v2 = nan and thus needs to be dropped.
What you currently do in solver() is: take the non-NaN elements of v as m, take the corresponding rows of A as M, and then compute A_for_solving = M.T @ M and B_for_solving = M.T @ m, i.e.
m = [v1 v3] # Masked v (n < N elements)
M = [[a11 a12 a13] # Masked A (nxM elements)
     [a31 a32 a33]]

A_for_solving = M.T @ M # MxM elements
B_for_solving = M.T @ m # M elements
result = linalg.solve(A_for_solving, B_for_solving)
Here you should note two things:

1. The shapes of A_for_solving and B_for_solving always stay the same, no matter how many elements of v (and rows of A) get dropped: A_for_solving is always an M×M matrix and B_for_solving is always an M-element vector. This hints that we may, in fact, still be able to compute in parallel.

2. More importantly: if you replace the NaNs in v, as well as the corresponding rows of A, with zeros, you will produce exactly the same values in A_for_solving and B_for_solving! That is because zeroed rows contribute nothing to the inner products that make up M.T @ M and M.T @ m.

In other words, you could do the following:
z = [v1 0 v3] # Zeroed v
Z = [[a11 a12 a13] # Zeroed A
     [ 0   0   0 ]
     [a31 a32 a33]]

A_for_solving = Z.T @ Z
B_for_solving = Z.T @ z
result = linalg.solve(A_for_solving, B_for_solving)
…and you will get exactly the same inputs to linalg.solve() as before! You can easily check this with your current code by extending it for testing as follows:
def solver(input_vector, A):
    mask = ~torch.isnan(input_vector)
    input_vector_masked = input_vector[mask]
    A_masked = A[mask]
    A_trans = A_masked.T
    # Start sanity check: nan-zeroing is the same as nan-dropping
    A_zeroed = A.clone(); A_zeroed[~mask] = 0
    input_vector_zeroed = input_vector.clone(); input_vector_zeroed[~mask] = 0
    assert torch.allclose(A_masked.T @ A_masked,
                          A_zeroed.T @ A_zeroed, atol=1e-5)
    assert torch.allclose(A_masked.T @ input_vector_masked,
                          A_zeroed.T @ input_vector_zeroed, atol=1e-5)
    # End sanity check
    return torch.linalg.solve(A_trans @ A_masked, A_trans @ input_vector_masked)
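For instance, a minimal spot check, reusing the tensors defined in the question (the cell indices here are arbitrary choices of mine):

# Spot-check a few cells: the asserts inside solver() will fire
# if zeroing and dropping ever produce different solve() inputs.
for i, j in [(0, 0), (123, 456), (250, 250)]:
    solver(input_tensor[:, i, j], A)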
If we use the zeroing approach, we can parallelize the code again, because we now have same-sized inputs for every mask. The corresponding function could look as follows:
def solver_batch(inpt, a):
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    mask = torch.isnan(inpt)  # CAUTION: True for NaNs, unlike `mask` in the question!

    a_zeroed = a.repeat(*inpt.shape[:2], 1, 1)  # BxCxNxM
    a_zeroed[mask.expand(-1, -1, -1, a.shape[-1])] = 0
    at_a = a_zeroed.transpose(-2, -1) @ a_zeroed  # BxCxMxM

    inpt_zeroed = inpt.clone()
    inpt_zeroed[mask] = 0
    at_input = a_zeroed.transpose(-2, -1) @ inpt_zeroed  # BxCxMx1

    result = torch.linalg.solve(at_a, at_input)
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
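As a side note on a possible variation (a sketch I have not benchmarked, so any speed difference is an assumption on my part): instead of materializing a via repeat() and zeroing rows with boolean indexing, one can zero by multiplying with the validity mask and let broadcasting produce the BxCxNxM tensor:

def solver_batch_mul(inpt, a):
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)           # BxCxNx1
    valid = ~torch.isnan(inpt)                           # True for usable elements
    a_zeroed = a * valid                                 # broadcasts to BxCxNxM; NaN rows become 0
    inpt_zeroed = torch.nan_to_num(inpt, nan=0.0)        # zero the NaNs themselves
    at_a = a_zeroed.transpose(-2, -1) @ a_zeroed         # BxCxMxM
    at_input = a_zeroed.transpose(-2, -1) @ inpt_zeroed  # BxCxMx1
    result = torch.linalg.solve(at_a, at_input)
    return result.squeeze(-1).permute(2, 0, 1)           # MxBxC

Both versions build the same at_a and at_input; the multiplication form merely avoids the large boolean advanced-indexing step.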
The batched solution is quite similar to the answer that I posted to your previous question, but there are two caveats:

First, since we now need a different matrix A.T @ A for each input vector, in the given example we end up with a tensor at_a of size 500×500×100×100. That is huge (here, a tensor of 2.5 billion elements); in my case it did not fit into GPU memory. What I did instead was process the input tensor in chunks:
chunk_size = 50  # TODO: adjust chunk size for your hardware
for lo in range(0, input_tensor.shape[-1], chunk_size):
    chunk_result = solver_batch(input_tensor[..., lo:lo+chunk_size], A)
    host_tensor[..., lo:lo+chunk_size] = chunk_result
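To pick chunk_size, a rough lower-bound estimate of the dominant allocation may help. This only counts at_a for one chunk in float32 and ignores a_zeroed and the solver's workspace; the numbers restate the question's sizes:

# at_a for one chunk has shape B x chunk_size x M x M in float32 (4 bytes)
B, M = 500, 100
chunk_size = 50
at_a_bytes = B * chunk_size * M * M * 4
print(f"at_a per chunk: {at_a_bytes / 1e9:.1f} GB")  # 1.0 GB for these sizes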
This is still much faster than processing the inputs element by element.
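For reference, one could time the chunked loop the same way the question does; this is a sketch reusing the names defined above, and absolute numbers will of course depend on your hardware:

import time

t0 = time.perf_counter()
for lo in range(0, input_tensor.shape[-1], chunk_size):
    host_tensor[..., lo:lo+chunk_size] = solver_batch(input_tensor[..., lo:lo+chunk_size], A)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for queued GPU work before reading the clock
t1 = time.perf_counter()
print(f"Batched inversion took {t1 - t0:.2f}s")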
Second, I tried to sanity-check the results with the following for loop, similar to the sanity check in my previous answer:
for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        input_vec = input_tensor[..., i, j]
        res_vec = host_tensor[..., i, j]
        mask = ~torch.isnan(input_vec)
        M = A[mask]
        assert torch.allclose((M.T @ M) @ res_vec, M.T @ input_vec[mask], atol=1e-3)
What we check here is that, if X = solve(A, B), then A @ X = B should hold. However, this does not seem to be the case for the given data, neither with my approach nor with yours. I don't know whether this is a problem of numerical instability (my math skills are lacking there) or whether I am making some silly mistake.
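One way to probe the numerical-instability hypothesis (an assumption of mine, not something I have verified for this data) is to look at the condition number: forming A.T @ A roughly squares the condition number of the masked design matrix, which with random data can be enough to break an atol=1e-3 check.

# Hypothetical diagnostic for a single cell: inspect conditioning
i, j = 0, 0
input_vec = input_tensor[..., i, j]
mask = ~torch.isnan(input_vec)
M_masked = A[mask]
print(torch.linalg.cond(M_masked))               # condition number of the masked A
print(torch.linalg.cond(M_masked.T @ M_masked))  # roughly its square; drives the solve's error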