`torch.einsum` API 是如何工作的?

问题描述 投票:0回答:1

torch.einsum
API 是如何工作的?

我试图理解如何

torch.einsum("ac,bc->ab",norm_max_func_embedding,norm_nl_embedding)
正在计算相似度?

我知道这是对张量进行操作。

我认为“ac”指定尺寸为(a,c)的张量。但是“bc->ab”正在做什么。还有它是如何计算相似度的。我认为相似度可以通过余弦相似度或欧几里得距离来计算。

# Encode maximum function
func = "def f(a,b): if a>b: return a else return b"
tokens_ids = model.tokenize([func],max_length=512,mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
tokens_embeddings,max_func_embedding = model(source_ids)

# Encode minimum function
func = "def f(a,b): if a<b: return a else return b"
tokens_ids = model.tokenize([func],max_length=512,mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
tokens_embeddings,min_func_embedding = model(source_ids)

norm_max_func_embedding = torch.nn.functional.normalize(max_func_embedding, p=2, dim=1)
norm_min_func_embedding = torch.nn.functional.normalize(min_func_embedding, p=2, dim=1)
norm_nl_embedding = torch.nn.functional.normalize(nl_embedding, p=2, dim=1)

max_func_nl_similarity = torch.einsum("ac,bc->ab",norm_max_func_embedding,norm_nl_embedding)
min_func_nl_similarity = torch.einsum("ac,bc->ab",norm_min_func_embedding,norm_nl_embedding)

我指的是这个github存储库:https://github.com/microsoft/CodeBERT/tree/master/UniXcoder

它测量什么样的相似度?

非常感谢任何帮助或文档指示。

python pytorch torch embedding word-embedding
1个回答
0
投票

ac,bc->ab
指的是两个输入和一个输出张量。左侧标记输入的维度,以逗号分隔。右侧显示输出张量的维度(
->
之后的所有内容)。

  • 第一个输入张量的尺寸为
    a
    c
  • 第二个输入张量的尺寸为
    b
    c

由于

c
被两者共享,因此两个张量的维度必须相同。其他两个维度是“自由”的。

  • 输出张量的尺寸为
    a
    b

由于

c
没有出现在输出中,我们沿着该维度进行和积。如果
c
确实出现在输出中,那么它只是一个产品。例如
ab,ab->ab
是两个形状相同的矩阵的逐元素乘积。

尺寸的顺序很重要。例如

ab->ba
是转置操作。

© www.soinside.com 2019 - 2024. All rights reserved.