我想匹配列表中张量的索引。 我正在尝试使用 Pytorch 进行链接预测。 在这个过程中,我需要通过将索引映射到字典来将其转换为名称。 为此,我将字典和掩码设置为张量,但它返回了意外的索引。
inv_entity_dict = {v: k for k, v in entity_dict.items()}
inv_entity_dict
#{0: 'TMEM35A',
# 1: 'FHL5',
# 2: 'Sirolimus',
# 3: 'TMCO2',
# 4: 'RNF123',
# 5: 'SMURF2',
# 6: 'SSH3',
# 7: 'PSMA4',
# 8: 'SOD3',
# 9: 'SCOC',
# 10: 'Cysteamine',
# 11: 'TOX',
#...}
nonzero[0:10]
#array([ 0, 1, 3, 4, 5, 6, 7, 8, 9, 11])
运行代码后,它返回了意外的结果,因为西罗莫司(idx==2)不在非零数组中,不应该与名称匹配。
for i in range(1):
raw_probs = (z[i][nonzero[0:10]] @ z[i][nonzero[0:10]].t()).sigmoid()
filtered_probs = pd.DataFrame((raw_probs>0.9).nonzero(as_tuple=False).cpu().numpy(), columns=['Gene1', 'Gene2'])
filtered_probs['prob'] = raw_probs[(raw_probs>0.9)].cpu().detach().numpy()
filtered_probs_name = map_id2gene(filtered_probs, inv_entity_dict) #converting func.
#Expected result
# Gene1 Gene2 prob
#67 TOX TOX 1.0
#0 TMEM35A TMEM35A 1.0
#1 TMEM35A FHL5 1.0
#2 TMEM35A RNF123 1.0
#52 SCOC TMEM35A 1.0
#Wrong
# Gene1 Gene2 prob
#67 SCOC SCOC 1.0
#0 TMEM35A TMEM35A 1.0
#1 TMEM35A FHL5 1.0
#2 TMEM35A Sirolimus 1.0
#52 SOD3 TMEM35A 1.0
我猜初始化的
raw_probs
索引直接进入了转换过程。
raw_prob
#tensor([[1.0000e+00, ..., 1.0000e+00], #real index: 0
# [1.0000e+00, ..., 1.0000e+00], #real index: 1
# [1.0000e+00, ..., 1.0000e+00], #real index: 3, but considered to 2
# [1.0000e+00, ..., 1.0000e+00], #real index: 4, but considered to 3, ...
# [1.0000e+00, ..., 1.0000e+00], #real index: 5
# [1.0000e+00, ..., 1.0000e+00], #real index: 6
# [0.0000e+00, ..., 0.0000e+00], #real index: 7
# [0.0000e+00, ..., 4.4097e-36], #real index: 8
# [1.0000e+00, ..., 1.0000e+00], #real index: 9
# [1.0000e+00, ..., 1.0000e+00] #real index: 11, but considered to 9], device='cuda:0')
在这种情况下,如何根据
inv_entity_dict
和 nonzero
列表匹配正确的 id 和名称?
好吧,我只是像这样添加新词典:
d = dict(enumerate(map(int, nonzero)))
filtered_probs['Gene1'] = filtered_probs.apply(lambda L: d[L.iloc[0]], axis=1)
filtered_probs['Gene2'] = filtered_probs.apply(lambda L: d[L.iloc[1]], axis=1)
然后返回正确的结果。
如果您有更好的意见,欢迎评论!