我想通过 pytorch geo 获取邻接列表。 pytorch中是否有任何方法可以做到这一点或任何转换邻接矩阵但经过优化而不是循环的方法?
像这样:
import numpy as np
import networkx as nx
def adjacency_matrix_to_list(adj_matrix):
# Créer un graphe à partir de la matrice d'adjacence
graph = nx.from_numpy_matrix(adj_matrix)
# Convertir le graphe en liste d'adjacence
adj_list = nx.to_dict_of_lists(graph)
return adj_list
不使用循环我想你的意思是避免像这样的嵌套循环:
import torch
def adjacency_matrix_to_list_with_loops(adj_matrix):
num_nodes = adj_matrix.shape[0]
adj_list = []
for i in range(num_nodes):
neighbors = []
for j in range(num_nodes):
if adj_matrix[i][j] == 1:
neighbors.append(j)
adj_list.append(neighbors)
return adj_list
如果您不想使用像
networkx
这样的附加库,那么您将需要使用至少一个循环来迭代矩阵节点。
torch.nonzero
的结果是形状为(n, 1)
的张量(其中n
是非零元素的数量)。 .squeeze()
方法删除了单维,产生形状为 (n,)
的张量。
import torch
def adjacency_matrix_to_list(adj_matrix):
num_nodes = adj_matrix.shape[0]
adj_list = []
for i in range(num_nodes):
# Find indices of non-zero elements in the i-th row
neighbors = torch.nonzero(adj_matrix[i]).squeeze()
adj_list.append(neighbors.tolist())
return adj_list
该方法的性能取决于图的大小和邻接矩阵的稀疏性。对于非常大的图,
networkx
解决方案可能是一种更优化的方法。