为什么PyG中的LSTM聚合需要对edge_index进行排序?

问题描述 投票:0回答:1

您好,我使用了GraphSAGE来进行节点嵌入。我选择用于聚合的函数是LSTM以及用于图神经网络的PyG库,它需要的参数如下:

输入 1:节点特征 (|V|, F_in) - 这里我使用 2D 平面中的节点坐标 x-y (V x 2),并且已经标准化为 [0, 1] 范围,例如

          x         y
0  0.374540  0.598658
1  0.950714  0.156019
2  0.731994  0.155995

输入 2:边索引 (2, |E|) - 邻接矩阵 (V x V),但仅从我拥有的原始邻接矩阵中检索边到 (2, |E|)

idx   0  1  2
0   [[0, 1, 1], 
1    [1, 0, 1], 
2    [1, 1, 0]]

上图中我们有一个具有 6 条边的形状 (V x V)。我们必须对其进行一些改造以适应 PyG 对形状 (2, |E|) 的使用,我想将其称为

edge_index
,其中边为 (0, 1), (0, 2), (1, 0 ), (1, 2), (2, 0), (2, 1)。

idx   0  1  2  3  4  5
0   [[0, 0, 1, 1, 2, 2],
1    [1, 2, 0, 2, 0, 1]]

输出:节点特征(|V|,F_out) - 与节点坐标类似,但它们不再是二维的,它们位于具有 F_out 维度的新嵌入维度中。

我的问题是,当使用 LSTM 聚合器时,它被迫排序

edge_index
(input2 中的边缘索引),否则会显示错误
ValueError: Can not perform aggregation since the 'index' tensor. is not sorted.

所以我必须做排序用以下命令给出它:

# inside def __init__()
self.graph_sage=SAGEConv(in_channels=2, out_channels=hidden_dim, aggr='lstm')

# inside def forward()
sorted_edge_index, _ = torch.sort(edge_index, dim=1)  # for LSTM
x = self.graph_sage(coord.view(-1, 2), sorted_edge_index)  # using GraphSAGE

排序后

sorted_edge_index
张量将如下所示。

idx   0  1  2  3  4  5
0   [[0, 0, 1, 1, 2, 2],
1    [0, 0, 1, 1, 2, 2]]

我注意到,在连接 3 个节点的全网格图中,当我对它进行排序时,边可以被重新解释为 (0, 0)、(0, 0)、(1, 1)、(1, 1)、 (2, 2), (2, 2) 这让我很好奇。我的问题是以下两件事。

  1. 为什么LSTM需要对
    edge_index
    进行排序?
  2. 像这样对
    edge_index
    进行排序后,我的模型将如何知道哪些节点已连接?因为所有原来的边关系对都没有了。这就像发送图中不存在的边对作为输入。这会是一个缺点吗?

我已经尝试过执行上述操作,并且运行良好。但我有一些疑问,希望有知识的人能够帮助像我这样的初学者澄清问题。我真诚地希望这个问题对其他学习 GNN 的学生也有用。

python pytorch neural-network lstm pytorch-geometric
1个回答
0
投票

按行排序 = False

from torch_geometric.utils import sort_edge_index

sort_edge_index=sort_edge_index(edge_index, num_nodes=self.num_nodes, sort_by_row=False)
x=self.graphsage(coord.view(-1, 2), sort_edge_index)

https://github.com/pyg-team/pytorch_geometric/discussions/8908

© www.soinside.com 2019 - 2024. All rights reserved.