如何使用任何工具可视化HeteroData pytorch几何图形?

问题描述 投票:0回答:2

您好,可视化 pyg HeteroData 对象的好方法是什么? (定义类似:https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html#creating-heterogeneous-gnns

我尝试使用networkx,但我认为它仅限于同质图(可以转换它,但信息量要少得多)。

g = torch_geometric.utils.to_networkx(data.to_homogeneous(), to_undirected=False )

有人尝试用其他 python lib (matplotlib) 或 js (sigma.js/d3.js) 来做到这一点吗?

您可以分享任何文档链接吗?

networkx visualization graph-visualization pytorch-geometric
2个回答
0
投票

您可以使用 networkx 来完成此操作,但您需要进行一些编码来告诉它如何格式化节点和边。

# Simple example of network x rendering with colored nodes and edges
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

graph = to_networkx(data, to_undirected=False)

# Define colors for nodes and edges
node_type_colors = {
    "Station": "#4599C3",
    "Lot": "#ED8546",
}

node_colors = []
labels = {}
for node, attrs in graph.nodes(data=True):
    node_type = attrs["type"]
    color = node_type_colors[node_type]
    node_colors.append(color)
    if attrs["type"] == "Station":
        labels[node] = f"S{node}"
    elif attrs["type"] == "Lot":
        labels[node] = f"L{node}"

# Define colors for the edges
edge_type_colors = {
    ("Lot", "SameSetup", "Station"): "#8B4D9E",
    ("Station", "ShortSetup", "Lot"): "#DFB825",
    ("Lot", "SameEnergySetup", "Station"): "#70B349",
    ("Station", "ProcessNow", "Lot"): "#DB5C64",
}

edge_colors = []
for from_node, to_node, attrs in graph.edges(data=True):
    edge_type = attrs["type"]
    color = edge_type_colors[edge_type]

    graph.edges[from_node, to_node]["color"] = color
    edge_colors.append(color)


# Draw the graph
pos = nx.spring_layout(graph, k=2)
nx.draw_networkx(
    graph,
    pos=pos,
    labels=labels,
    with_labels=True,
    node_color=node_colors,
    edge_color=edge_colors,
    node_size=600,
)
plt.show()

-1
投票

我已完成以下操作:

import networkx as nx
from matplotlib import pyplot as plt
from torch_geometric.nn import to_hetero

g = torch_geometric.utils.to_networkx(data.to_homogeneous())
# Networkx seems to create extra nodes from our heterogeneous graph, so I remove them
isolated_nodes = [node for node in g.nodes() if g.out_degree(node) == 0]
[g.remove_node(i_n) for i_n in isolated_nodes]
# Plot the graph
nx.draw(g, with_labels=True)
plt.show()

但是,它确实被“扁平化”为同质的,而例如,对不同类型的节点使用不同的颜色会更有趣。

© www.soinside.com 2019 - 2024. All rights reserved.