Column Wise Dot-Product torch.einsum 不匹配 torch.sum(torch.mul(), axis=0)

问题描述 投票:0回答:0

我正在尝试在两个张量的列之间执行点积。我试图以最有效的方式做到这一点。但是,我的两种方法不匹配。

我使用

torch.sum(torch.mul(a, b), axis=0)
的第一个方法给了我预期的结果,
torch.einsum('ji, ji -> i', a, b)
(取自 Efficient method to compute the row-wise dot product of two square matrices of the same size in PyTorch)没有。可重现的代码如下:

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

a = torch.randn(3,1, dtype=torch.float).to(device)
b = torch.randn(3,4, dtype=torch.float).to(device)

print(f"a : \n{a}\n")
print(f"b : \n{b}\n")
print(f"Expected:    {a[0,0]*b[0,0] + a[1,0]*b[1,0] + a[2,0]*b[2,0]}")

c = torch.sum(torch.mul(a, b), axis=0)
print(f"sum and mul: {c[0].item()}")

d = torch.einsum('ji, ji -> i', a, b)
print(f"einsum:      {d[0].item()}\n")

print(torch.eq(c,d))

输出为:

注意事项: 在 CPU 上(我所做的只是删除

.to(device)
)最后一行
torch.eq(c,d)
都是真的但是,我需要张量在 GPU 上。

还有一些种子如

torch.manual_seed(100)
张量是相等的...

我觉得它必须是

einsum
的东西,因为我可以通过其他方式得到我预期的答案。

python pytorch tensor
© www.soinside.com 2019 - 2024. All rights reserved.