如何使用pytorch匹配列表中张量和值的索引?

问题描述 投票:0回答:1

我想匹配列表中张量的索引。 我正在尝试使用 Pytorch 进行链接预测。 在这个过程中,我需要通过将索引映射到字典来将其转换为名称。 为此,我将字典和掩码设置为张量,但它返回了意外的索引。

inv_entity_dict = {v: k for k, v in entity_dict.items()}
inv_entity_dict
#{0: 'TMEM35A',
# 1: 'FHL5',
# 2: 'Sirolimus',
# 3: 'TMCO2',
# 4: 'RNF123',
# 5: 'SMURF2',
# 6: 'SSH3',
# 7: 'PSMA4',
# 8: 'SOD3',
# 9: 'SCOC',
# 10: 'Cysteamine',
# 11: 'TOX',
#...}

nonzero[0:10]
#array([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 11])

运行代码后,它返回了意外的结果,因为西罗莫司(idx==2)不在非零数组中,不应该与名称匹配。

for i in range(1):
    raw_probs = (z[i][nonzero[0:10]] @ z[i][nonzero[0:10]].t()).sigmoid()
    filtered_probs = pd.DataFrame((raw_probs>0.9).nonzero(as_tuple=False).cpu().numpy(), columns=['Gene1', 'Gene2'])
    filtered_probs['prob'] = raw_probs[(raw_probs>0.9)].cpu().detach().numpy()
    filtered_probs_name = map_id2gene(filtered_probs, inv_entity_dict) #converting func.

#Expected result
#   Gene1   Gene2   prob
#67 TOX TOX 1.0
#0  TMEM35A TMEM35A 1.0
#1  TMEM35A FHL5    1.0
#2  TMEM35A RNF123  1.0
#52 SCOC    TMEM35A 1.0

#Wrong
#   Gene1   Gene2   prob
#67 SCOC    SCOC    1.0
#0  TMEM35A TMEM35A 1.0
#1  TMEM35A FHL5    1.0
#2  TMEM35A Sirolimus   1.0
#52 SOD3    TMEM35A 1.0

我猜初始化的

raw_probs
索引直接进入了转换过程。

raw_prob
#tensor([[1.0000e+00, ..., 1.0000e+00], #real index: 0
#        [1.0000e+00, ..., 1.0000e+00], #real index: 1
#        [1.0000e+00, ..., 1.0000e+00], #real index: 3, but considered to 2
#        [1.0000e+00, ..., 1.0000e+00], #real index: 4, but considered to 3, ...
#        [1.0000e+00, ..., 1.0000e+00], #real index: 5
#        [1.0000e+00, ..., 1.0000e+00], #real index: 6
#        [0.0000e+00, ..., 0.0000e+00], #real index: 7
#        [0.0000e+00, ..., 4.4097e-36], #real index: 8
#        [1.0000e+00, ..., 1.0000e+00], #real index: 9
#        [1.0000e+00, ..., 1.0000e+00] #real index: 11, but considered to 9], device='cuda:0')

在这种情况下,如何根据

inv_entity_dict
nonzero
列表匹配正确的 id 和名称?

python pytorch
1个回答
0
投票

好吧,我只是像这样添加新词典:

d = dict(enumerate(map(int, nonzero)))

filtered_probs['Gene1'] = filtered_probs.apply(lambda L: d[L.iloc[0]], axis=1)
filtered_probs['Gene2'] = filtered_probs.apply(lambda L: d[L.iloc[1]], axis=1)

然后返回正确的结果。
如果您有更好的意见,欢迎评论!

© www.soinside.com 2019 - 2024. All rights reserved.