How do I extract graph node embeddings from a PyTorch Geometric GAT model?

Dataset structure: a temporal directed graph; nodes have features, edges have no features, and nodes are labeled. I am using the Elliptic dataset.

Task: classify nodes, i.e. predict node labels.

Data structure: two .csv files, one for nodes and one for edges.

  • For the node csv:
    #Rows = #Nodes
    #Columns = #Features
  • For the edge csv:
    #Rows = #Edges
  • Finally, both files are converted to tensors and turned into a PyTorch Geometric Data object.

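I want to train various graph neural networks on the data and extract node embeddings from them. I know this is possible, because the authors of the Elliptic dataset extracted node embeddings from a GCN.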

Below is the code for the GAT I am using:

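import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

class GAT(torch.nn.Module):
  """Graph Attention Network"""
  def __init__(self, dim_in, dim_h, dim_out, heads=24):
    super().__init__()
    self.gat1 = GATv2Conv(dim_in, dim_h, heads=heads)
    # gat1 concatenates its heads, so gat2 receives dim_h * heads features
    self.gat2 = GATv2Conv(dim_h*heads, dim_out, heads=1)
    self.optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.25,
                                      weight_decay=5e-4)

  def forward(self, x, edge_index):
    h = F.dropout(x, p=0.5, training=self.training)
    h = self.gat1(h, edge_index)  # feed the dropped-out features h, not the raw x
    h = F.elu(h)
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.gat2(h, edge_index)
    # Return the raw scores (dim_out per node) and the log-probabilities
    return h, F.log_softmax(h, dim=1)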

This function trains the model and returns it:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model, data, epochs=200):
    """Train a GNN model and return the trained model."""
    # forward() already returns log-probabilities, so NLLLoss is the matching criterion
    criterion = torch.nn.NLLLoss()
    optimizer = model.optimizer

    model = model.to(device)
    data = data.to(device)  # move features, edges, labels and masks once

    model.train()
    for epoch in range(epochs + 1):
        # Training
        optimizer.zero_grad()
        _, out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Print metrics every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch:>3} | Train Loss: {loss.item():.3f}')

    return model
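A note on the loss: `CrossEntropyLoss` expects raw scores and applies `log_softmax` internally, while the model's second output is already a log-probability. Re-applying `log_softmax` to log-probabilities leaves them unchanged (their exponentials already sum to 1), so the loss value is the same either way; `NLLLoss` on the log-probabilities, or `CrossEntropyLoss` on the raw first output, just states the intent more directly. The toy tensors below are made up to check this numerically.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 3)               # raw scores: 5 nodes, 3 classes
target = torch.tensor([0, 2, 1, 1, 0])   # made-up labels
log_probs = F.log_softmax(logits, dim=1)

ce_on_logits = F.cross_entropy(logits, target)        # CrossEntropyLoss on raw scores
nll_on_log_probs = F.nll_loss(log_probs, target)      # NLLLoss on log-probabilities
ce_on_log_probs = F.cross_entropy(log_probs, target)  # CrossEntropyLoss on log-probabilities

print(ce_on_logits.item(), nll_on_log_probs.item(), ce_on_log_probs.item())
```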

What changes do I need to make to my code to extract node embeddings?

python deep-learning pytorch pytorch-geometric self-attention
1 Answer

Hi, you can write a method similar to the one used for processing large graphs with a subgraph loader:

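    # Method of a model whose layers are stored in self.convs (a torch.nn.ModuleList).
    # subgraph_loader, device and tqdm come from the linked example.
    @torch.no_grad()
    def representation(self, x_all):
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        # Compute node representations layer by layer, over all nodes
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = F.elu_(x)
                # Keep only the seed nodes of this batch (its first batch_size rows)
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all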
From https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py

If the graph is not very large, you can also use get_embeddings from the PyTorch Geometric utilities:

https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/embedding.html
