DGL output cannot be mapped back to the original node IDs


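I am running a program that builds a heterogeneous graph and then uses NetworkX to compute network metrics for the target nodes (the nodes the edges point to). However, in the output Excel file, the column for these target nodes (column 's') contains repeated values that do not match their original ID numbers.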

```python
import pandas as pd
import dgl
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from dgl.nn import GraphConv
from sklearn.preprocessing import LabelEncoder
import numpy as np
import networkx as nx

# Load data from Excel file

# Create a LabelEncoder object
label_encoder = LabelEncoder()

# Extract data from different sheets
...

# Initialize result_df here
...
# Walk through the data and collect node IDs before creating the graph
node_set = set()
for index, row in data.iterrows():
    node_set.add(row['s'])
    if 'sh' in row:
        node_set.add(row['sh'])
    if 'b' in row:
        node_set.add(row['b'])
    if 'd' in row:
        node_set.add(row['d'])

num_nodes = len(node_set)

# Convert data to int32 type
...

# Create the d-s, sh-s and b-s graphs in turn
...

# Use the encoded node IDs when creating the heterogeneous graph
def create_hetero_graph(e_data, l_data, c_data):
    # sh -> s edges
    src_sh = e_data['sh_id'].values.astype(np.int64)
    dst_s_sh = e_data['s_id'].values.astype(np.int64)

    # b -> s edges
    src_b = l_data['b_id'].values.astype(np.int64)
    dst_s_b = l_data['s_id'].values.astype(np.int64)

    # d -> s edges
    src_d = c_data['d_id'].values.astype(np.int64)
    dst_s_d = c_data['s_id'].values.astype(np.int64)

    edges = {
        ('sh', 'sh_to_s', 's'): (src_sh, dst_s_sh),
        ('b', 'b_to_s', 's'): (src_b, dst_s_b),
        ('d', 'd_to_s', 's'): (src_d, dst_s_d),
    }

    # Each node type has its own ID space 0..max_id; counts cast to plain int
    num_nodes_dict = {
        'sh': int(src_sh.max()) + 1,
        'b': int(src_b.max()) + 1,
        'd': int(src_d.max()) + 1,
        's': int(max(dst_s_sh.max(), dst_s_b.max(), dst_s_d.max())) + 1,
    }

    g = dgl.heterograph(edges, num_nodes_dict=num_nodes_dict)

    s_encoder = LabelEncoder()
    s_encoder.fit(e_data['s'].astype(str))

    # Decoded 's' labels per edge type (computed here but not used further)
    decoded_s_sh = s_encoder.inverse_transform(dst_s_sh)
    decoded_s_b = s_encoder.inverse_transform(dst_s_b)
    decoded_s_d = s_encoder.inverse_transform(dst_s_d)

    return g

hetero_graph = create_hetero_graph(e_data, l_data, c_data)

nx_g = nx.Graph(hetero_graph.to_networkx().edges())

# Continue the network analysis

# Calculate the degree centrality of the nodes
...


# Add the centrality metrics to the result DataFrame
result_df = pd.concat([result_df, pd.DataFrame(s_centralities, index=[0])], ignore_index=True)

result_df.reset_index(drop=True, inplace=True)


unique_labels = label_encoder.classes_

# Replace unknown labels with the most common category
most_common_label = result_df['s_id'].mode()[0]  # Gets the most common category
result_df['s_id'] = result_df['s_id'].apply(lambda x: x if x in unique_labels else most_common_label)

# Handle NaN values in the result DataFrame
result_df['s_id'] = result_df['s_id'].fillna(most_common_label)  # Fill NaN values with the most common category

# Inverse-transform the labels
result_df['s'] = label_encoder.inverse_transform(result_df['s_id'].astype(int))

# Number of nodes
num_nodes = len(nx_g.nodes())

# Build a Series of node centrality values
degree_centrality_series = pd.Series([degree_centrality.get(node, 0) for node in range(num_nodes)])


# Add the node centrality metric to the result DataFrame
result_df['degree_centrality'] = degree_centrality_series


# Save the DataFrame as an Excel file (including centrality metrics)
...
```
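One mechanism that produces exactly this kind of repeated value is worth checking. LabelEncoder.classes_ holds the original labels, while result_df['s_id'] holds encoded integers 0..n-1, so the test "x in unique_labels" compares codes against labels. When the labels are strings (or integers outside 0..n-1), many or all codes fail the test and are replaced by the mode, and inverse_transform then returns the same label on every affected row. The sketch below reproduces this in plain Python; fit_label_encoder, encode, decode, mode_first and the two replace_unknown_* functions are hypothetical helpers that only mimic LabelEncoder and Series.mode()[0], not the sklearn or pandas implementations.

```python
from collections import Counter
from typing import List, Sequence


def fit_label_encoder(labels: Sequence[str]) -> List[str]:
    """Mimic LabelEncoder.fit: the classes are the sorted unique labels."""
    return sorted(set(labels))


def encode(labels: Sequence[str], classes: List[str]) -> List[int]:
    """Mimic LabelEncoder.transform: each label becomes its index in classes."""
    index = {label: i for i, label in enumerate(classes)}
    return [index[label] for label in labels]


def decode(codes: Sequence[int], classes: List[str]) -> List[str]:
    """Mimic LabelEncoder.inverse_transform."""
    return [classes[code] for code in codes]


def mode_first(codes: Sequence[int]) -> int:
    """Mimic Series.mode()[0]: the smallest of the most frequent values."""
    counts = Counter(codes)
    top = max(counts.values())
    return min(code for code, n in counts.items() if n == top)


def replace_unknown_buggy(codes: Sequence[int], classes: List[str]) -> List[int]:
    # Codes are compared with the original labels; with string labels none match.
    most_common = mode_first(codes)
    return [code if code in classes else most_common for code in codes]


def replace_unknown_fixed(codes: Sequence[int], classes: List[str]) -> List[int]:
    # A code is valid when it is an index into classes.
    most_common = mode_first(codes)
    valid = range(len(classes))
    return [code if code in valid else most_common for code in codes]


labels = ["s17", "s03", "s42", "s03"]
classes = fit_label_encoder(labels)  # ["s03", "s17", "s42"]
codes = encode(labels, classes)      # [1, 0, 2, 0]
print(decode(replace_unknown_buggy(codes, classes), classes))  # ['s03', 's03', 's03', 's03']
print(decode(replace_unknown_fixed(codes, classes), classes))  # ['s17', 's03', 's42', 's03']
```

In the question's pandas code, the analogous check would test result_df['s_id'] against range(len(label_encoder.classes_)) rather than against classes_ itself.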

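I have debugged the code several times but still cannot get satisfactory output. I would appreciate any insightful ideas, as well as a deeper understanding of centrality in heterogeneous graphs.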
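On centrality in a heterogeneous graph: once the graph is flattened into a single NetworkX graph, nodes of every type share one integer ID space, so the k-th 's' node generally does not have ID k. The code above reads degree_centrality for range(num_nodes) over the whole graph and assigns the Series to result_df by row position, so rows can pick up values that belong to 'sh', 'b' or 'd' nodes. The plain-Python sketch below assumes the node types are laid out one after another with per-type offsets (the layout dgl.to_homogeneous uses, recording each node's original type and per-type ID in ndata[dgl.NTYPE] and ndata[dgl.NID]). It applies the same normalization as networkx.degree_centrality (degree divided by n - 1) and reports results keyed by the per-type 's' ID. The function names, the type order and the toy graph are hypothetical.

```python
from typing import Dict, List, Sequence, Set, Tuple

# (src_type, src_id, dst_type, dst_id), with IDs local to each node type
TypedEdge = Tuple[str, int, str, int]


def type_offsets(ntypes: Sequence[str], num_nodes: Dict[str, int]) -> Dict[str, int]:
    """First global ID of each node type when types are laid out in order."""
    offsets: Dict[str, int] = {}
    start = 0
    for ntype in ntypes:
        offsets[ntype] = start
        start += num_nodes[ntype]
    return offsets


def degree_centrality_for_type(
    ntypes: Sequence[str],
    num_nodes: Dict[str, int],
    edges: List[TypedEdge],
    target: str = "s",
) -> Dict[int, float]:
    """Degree centrality of the target-type nodes, keyed by their per-type ID."""
    offsets = type_offsets(ntypes, num_nodes)
    total = sum(num_nodes[t] for t in ntypes)
    if total <= 1:
        # networkx.degree_centrality assigns 1 to every node in this case
        return {k: 1.0 for k in range(num_nodes[target])}

    # Undirected simple graph: parallel edges collapse, as with nx.Graph(...).
    # Edges here always join two different node types, so there are no self-loops.
    neighbours: Dict[int, Set[int]] = {v: set() for v in range(total)}
    for src_type, src, dst_type, dst in edges:
        u = offsets[src_type] + src
        v = offsets[dst_type] + dst
        neighbours[u].add(v)
        neighbours[v].add(u)

    scale = 1.0 / (total - 1)
    base = offsets[target]
    # Translate back: global ID base + k is target node k
    return {k: len(neighbours[base + k]) * scale for k in range(num_nodes[target])}


ntypes = ["b", "d", "s", "sh"]  # hypothetical type order
num_nodes = {"b": 2, "d": 1, "s": 3, "sh": 1}
edges = [
    ("sh", 0, "s", 0),
    ("b", 0, "s", 0),
    ("b", 1, "s", 2),
    ("d", 0, "s", 2),
    ("d", 0, "s", 1),
]
# s0 and s2 have degree 2, s1 has degree 1; n - 1 = 6
print(degree_centrality_for_type(ntypes, num_nodes, edges))
```

Keying the output by the per-type ID means it can be joined to result_df on 's_id' instead of being assigned by row position.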

Tags: python, networkx, heterogeneous, gnn, dgl
1 Answer

This may be caused by a mismatch between the torch.int64 and Python int types.
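To make that concrete: if the IDs used as dictionary keys and the IDs used for lookup come from different sources (Python int, NumPy integers, torch tensors), lookups can silently fall back to the default value. torch.Tensor hashes by object identity, so a tensor will generally not match an int key, whereas NumPy integer scalars hash like the equivalent int. Casting every ID with int() before the lookup avoids this. In the sketch below, IdentityHashedId is a hypothetical stand-in for an identity-hashed object such as a tensor, and lookup_scores is a hypothetical helper.

```python
from typing import Dict, Iterable, List, SupportsInt


class IdentityHashedId:
    """Hypothetical stand-in for an ID object that hashes by identity, like torch.Tensor."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __int__(self) -> int:
        return self.value


def lookup_scores(scores: Dict[int, float], node_ids: Iterable[SupportsInt],
                  default: float = 0.0) -> List[float]:
    """Look up per-node scores after casting each ID to a plain Python int."""
    return [scores.get(int(node_id), default) for node_id in node_ids]


scores = {0: 0.5, 1: 0.25}
ids = [IdentityHashedId(1), IdentityHashedId(0), 7]

# Without the cast, identity-hashed keys miss and the default is used
print([scores.get(node_id, 0.0) for node_id in ids])  # [0.0, 0.0, 0.0]
# With the cast, the values are found (7 is genuinely absent)
print(lookup_scores(scores, ids))  # [0.25, 0.5, 0.0]
```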
