如果有2个以下大小的张量。
A = [N x 长 x T]
B = [N x T x K]
然后我想对两个张量的切片进行矩阵乘法。就像下面这样。
matmul_slice = A[0,:,:] @ B[0,:,:] = [L x T] @ [T x K] = [L x K]
然后我想沿着维度 = 0 做 N 次。 这样我最终得到大小为 [N,L,K] 的最终矩阵
我不想使用 N 上的循环,因为它会减慢计算速度。我一直在玩 torch.matmul 和 einsum,但我无法得到正确的答案......
如何以紧凑的方式实现这一目标?
torch.bmm
是您所需要的,尽管 torch.matmul
在您的情况下应该是等效的。我认为你应该重新检查你的计算。