在pytorch中并行乘以N个矩阵

问题描述 投票:0回答:1

我有一个大小为

L
A_1, A_2, ...
矩阵
N x N
列表。我想计算乘积
A_1 @ A_2 @ ... A_L
。我可以通过使用
torch.matmul
迭代计算乘积来计算该值。但是,如果我能够独立且并行地计算
(A_1 @ A_2) @ (A_3 @ A_4) ...
中的组件,则可以进一步并行化。我如何在 PyTorch 中执行此操作?

pytorch parallel-processing matrix-multiplication
1个回答
0
投票

矩阵乘法不可交换。乘法步骤的顺序很重要。请参阅此处

但是,尽管存在数学上的不等式,如果您想这样做,您可以通过二的倍数的相反索引来细分列表。令

A
L
矩阵的原始列表:

A1 = torch.stack(A[::2])
A2 = torch.stack(A[1::2])

然后您可以进行批量矩阵乘法,如果

A1
A1
的长度不同,则可能会省略
A2
的最后一个元素。

# multiply - output is of size [L//2,N,N]
A_out = torch.bmm(A1[:L2.shape[0]],A2) 

# concatenate the last element from L1 if necessary
if A1.shape[0] != A2.shape[0]:
   A_out = torch.cat((A_out,A1[-1,:,:]),dim = 0)
© www.soinside.com 2019 - 2024. All rights reserved.