Pytorch - 从 2 个张量的切片进行矩阵乘法

问题描述 投票:0回答:1

如果有2个以下大小的张量。

A = [N x 长 x T]

B = [N x T x K]

然后我想对两个张量的切片进行矩阵乘法。就像下面这样。

matmul_slice = A[0,:,:] @ B[0,:,:] = [L x T] @ [T x K] = [L x K]

然后我想沿着维度 = 0 做 N 次。 这样我最终得到大小为 [N,L,K] 的最终矩阵

我不想使用 N 上的循环,因为它会减慢计算速度。我一直在玩 torch.matmul 和 einsum,但我无法得到正确的答案......

如何以紧凑的方式实现这一目标?

pytorch tensor einsum matmul
1个回答
0
投票

torch.bmm
是您所需要的,尽管
torch.matmul
在您的情况下应该是等效的。我认为你应该重新检查你的计算。

© www.soinside.com 2019 - 2024. All rights reserved.