3,4轴火炬的矩阵乘法

问题描述 投票:2回答:1

我有两个张量分别为a(16,8,8,64)b(64,64)的张量。假设我将a的最后一个维度提取到另一个列向量c中,我想计算matmul(matmul(c.T, b), c)。我希望在a的前3个维度中的每个维度中都完成此操作。也就是说,最终产品的形状应为(16,8,8,1)。如何在pytorch中实现此目标?

python pytorch matrix-multiplication
1个回答
0
投票

可以执行以下操作:

row_vec = a[:, :, :, None, :].float()
col_vec = a[:, :, :, :, None].float()
b = (b[None, None, None, :, :]).float()
prod = torch.matmul(torch.matmul(row_vec, b), col_vec)
© www.soinside.com 2019 - 2024. All rights reserved.