Parallelizing a PyTorch function running on the GPU


Goal: I have a function, called in a loop, that takes a 1D tensor and a 2D tensor as input. Inside this function I use torch.linalg.solve(). I want to parallelize the loop to optimize the runtime.

Setup: I have 3 main tensors:

  1. input_tensor: size 50x100x100
  2. host_tensor: size 100x100x100
  3. A: size 50x100 (the design matrix)

input_tensor holds 100x100 input_vectors, each of length 50. They also contain varying numbers of NaNs, which I mask out, so a masked input_vector has a length less than or equal to 50. Note that the design matrix A gets masked as well and then has size (masked length x 100).

Since each input_vector and A has a different masked length, the function has to be run point by point.

Problem: Is there a way to make the code below faster? How do I deal with the design matrix A and the input_vector having a different size at each iteration?

Important note: the NaNs cannot simply be replaced with 0, as this would corrupt the linear solve. For background, I asked a question about a similar process here.

Code:

import torch
from tqdm import tqdm
import numpy as np
from datetime import datetime

# Create "device" so we can migrate the tensors to GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Set the seed for reproducibility
torch.manual_seed(42) 

# Set shapes to generate tensors
B, C = 500, 500
M, N = 100, 50

# Generate tensors
input_tensor = torch.randn(N, B, C)
host_tensor = torch.randn(M, B, C)
A = torch.randn(N, M)

# --- Here we input random NaNs in the input_tensor to simulate missing data --- #
# Define the probability of inserting NaN at each element
probability = 0.2  # You can adjust this as needed

# Generate random indices based on the probability
shape = input_tensor.shape
random_indices = torch.rand(shape) < probability

# Replace the selected indices with NaN values
input_tensor[random_indices] = float('nan')

# --- Migrate matrices to GPU --- #
A = A.to(device)
input_tensor = input_tensor.to(device)
host_tensor = host_tensor.to(device)

t_start = datetime.now()
# --- Function that creates a vector of size M from input_vector (size N) and the design matrix A --- #
def solver(input_vector, A):

    # We create a mask to reduce the row size of A: rows where input_vector is NaN are not considered in the solver
    mask = ~torch.isnan(input_vector)

    # Mask the vector
    input_vector_masked = input_vector[mask]

    # Mask the array
    A_masked = A[mask]
    A_trans = A_masked.T

    # Solve the normal equations: (A.T @ A) x = A.T @ input_vector
    return torch.linalg.solve(A_trans@A_masked, A_trans@input_vector_masked)


# --- Iterate through each vector of the input_tensor --- #

# Define the total number of iterations
total_iterations = B*C
# Create a tqdm progress bar
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)

# Iterate through every cell of input_tensor
for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        host_tensor[:,i,j] = solver(input_tensor[:,i,j], A)
        progress_bar.update(1)  # Update the progress bar
t_stop = datetime.now()


print(f"Inversion took {(t_stop - t_start).total_seconds():.2f}s")
Tags: python, optimization, pytorch, gpu, linear-algebra
1 Answer

The answer I arrived at here is a bit unsatisfying. But let's take it step by step.

Zeroing NaNs == dropping NaNs

First of all, you actually can replace the NaNs with zeros. As an example: say you have a vector v and a matrix A, given as

v = [v1 v2 v3]  # N elements
A = [[a11 a12 a13]  # NxM elements
     [a21 a22 a23]
     [a31 a32 a33]]

Now assume that v2 = nan and thus needs to be suppressed.

What you currently do in solver() is: take the non-NaN elements of v as m, take the corresponding rows of A as M, and then compute A_for_solving = M.T @ M and B_for_solving = M.T @ m, i.e.

m = [v1 v3]  # Masked v (n < N elements)
M = [[a11 a12 a13]  # Masked A (nxM elements)
     [a31 a32 a33]]
A_for_solving = M.T @ M  # MxM elements
B_for_solving = M.T @ m  # M elements
result = linalg.solve(A_for_solving, B_for_solving)

There are two things you should notice here:

  1. The shapes of A_for_solving and B_for_solving always stay the same, no matter how many elements of v (and rows of A) get dropped: A_for_solving is always an MxM matrix and B_for_solving is always an M-element vector. This hints that we might in fact still be able to compute in parallel.

  2. More importantly: if you replace the NaNs in v, as well as the corresponding rows in A, with zeros, you produce exactly the same values in A_for_solving and B_for_solving!

    In other words, you can do the following:

    z = [v1 0 v3]  # Zeroed v
    Z = [[a11 a12 a13]  # Zeroed A
         [  0   0   0]
         [a31 a32 a33]]
    A_for_solving = Z.T @ Z
    B_for_solving = Z.T @ z
    result = linalg.solve(A_for_solving, B_for_solving)
    

    …and you get exactly the same inputs to linalg.solve() as before! This works because a zeroed row contributes nothing to any of the sums that make up Z.T @ Z and Z.T @ z (a small standalone check is sketched right after this list).

You can easily verify this with your current code by extending it for a test as follows:

def solver(input_vector, A):

    mask = ~torch.isnan(input_vector)
    input_vector_masked = input_vector[mask]

    A_masked = A[mask]
    A_trans = A_masked.T
    
    # Start sanity check: nan-zeroing is the same as nan-dropping
    A_zeroed = A.clone(); A_zeroed[~mask] = 0
    input_vector_zeroed = input_vector.clone(); input_vector_zeroed[~mask] = 0
    assert torch.allclose(A_masked.T @ A_masked,
                          A_zeroed.T @ A_zeroed, atol=1e-5)
    assert torch.allclose(A_masked.T @ input_vector_masked,
                          A_zeroed.T @ input_vector_zeroed, atol=1e-5)
    # End sanity check
    
    return torch.linalg.solve(A_trans@A_masked, A_trans@input_vector_masked)

Batched computation

If we use the zeroing approach, we can parallelize our code again, because we now, once more, provide same-sized inputs for every mask. The corresponding function could look like this:

def solver_batch(inpt, a):
    inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
    mask = torch.isnan(inpt)  # CAUTION: True for NaNs, unlike `mask` in the question!
    a_zeroed = a.repeat(*inpt.shape[:2], 1, 1)  # BxCxNxM
    a_zeroed[mask.expand(-1, -1, -1, a.shape[-1])] = 0
    at_a = a_zeroed.transpose(-2, -1) @ a_zeroed  # BxCxMxM
    inpt_zeroed = inpt.clone()
    inpt_zeroed[mask] = 0
    at_input = a_zeroed.transpose(-2, -1) @ inpt_zeroed  # BxCxMx1
    result = torch.linalg.solve(at_a, at_input)
    return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
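If the intermediate tensors fit into memory, the whole i/j loop from the question then collapses into a single call (a sketch; see the memory caveat below before trying this at full size):

host_tensor = solver_batch(input_tensor, A)  # MxBxC, same layout the loop produced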

Caveats

The batched solution is very similar to the answer I posted to your previous question. There are two caveats though:

Caveat 1: Memory usage

Since we now need a different matrix A.T @ A for each input vector, in the given example we end up with a tensor at_a of size 500x500x100x100. That is huge (a tensor of 2.5 billion elements in this case, roughly 10 GB in float32). In my case it did not fit onto the GPU, so what I do instead is process the input tensor in chunks:

chunk_size = 50  # TODO: adjust chunk size for your hardware
for lo in range(0, input_tensor.shape[-1], chunk_size):
    chunk_result = solver_batch(input_tensor[..., lo:lo+chunk_size], A)
    host_tensor[..., lo:lo+chunk_size] = chunk_result
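To pick a chunk_size for your hardware, a rough estimate of the dominant intermediate helps (my own sketch, assuming float32 and the B, C, M from the question's code; the chunking slices the last, C-sized dimension):

bytes_per_elem = 4                                      # float32
chunk_size = 50
at_a_bytes = B * chunk_size * M * M * bytes_per_elem    # at_a per chunk has shape B x chunk_size x M x M
print(f"at_a per chunk: {at_a_bytes / 1e9:.2f} GB")     # 500*50*100*100*4 bytes ≈ 1 GB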

This is still much faster than processing the inputs element by element.

Caveat 2: Numerical instabilities

I tried to sanity-check the results with the following for loop, similar to the sanity check in my previous answer:

for i in range(host_tensor.shape[1]):
    for j in range(host_tensor.shape[2]):
        input_vec = input_tensor[..., i, j]
        res_vec = host_tensor[..., i, j]
        mask = ~torch.isnan(input_vec)
        M = A[mask]
        assert torch.allclose((M.T @ M) @ res_vec, M.T @ input_vec[mask], atol=1e-3)

What we check here is that, if X = solve(A, B), then A @ X = B should hold. This does not seem to be the case for the given data, however, neither with my approach nor with yours. I don't know whether this is a problem of numerical instabilities (my math skills are lacking there) or whether I made some stupid mistake somewhere.
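One way to probe whether ill-conditioning is the culprit (my own sketch, not part of the original answer): forming A.T @ A roughly squares the condition number of the masked design matrix, so inspecting one of the MxM systems can show whether the 1e-3 tolerance is simply too tight for float32:

# Check the conditioning of one of the MxM systems (hypothetical follow-up)
i, j = 0, 0
vec = input_tensor[:, i, j]
keep = ~torch.isnan(vec)
M_masked = A[keep]
print(torch.linalg.cond(M_masked))               # condition number of the masked design matrix
print(torch.linalg.cond(M_masked.T @ M_masked))  # roughly its square; this governs the solve
# If these are large, repeating the solve in float64 (e.g. A.double()) is a cheap way to test.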
