使用 PyTorch Geometric RandomNodeSplit 获取 train_mask、val_mask、test_mask 时遇到问题

问题描述 投票:0回答:1

我有一些来自 PyTorch Geometric 的

Data
对象,其形式为
Data(x=[621, 2], edge_index=[2, 1142], edge_attr=[1142, 1])

我想拆分数据以应用某些模型。我试过这个代码

import torch_geometric.transforms as T

split = T.RandomNodeSplit(split='random', num_val=0.1, num_test=0.2)
data = split(data)
data

但输出仍然是与之前相同的 Data 对象,没有

train_mask
val_mask
test_mask
。现在有人知道为什么吗?谢谢:)

我见过一些将此方法应用于 PyTorch Geometric 的自定义数据集(如 Planetoid)的示例,但我的数据不是其中之一。

通过文档,我在代码中没有发现任何错误。都没有找到任何其他例子。

pytorch networkx pytorch-geometric
1个回答
0
投票

您需要设置关键参数。默认值为 y,但如果您没有标签,可以将其设置为 key=None

随机节点分割(键=无)

© www.soinside.com 2019 - 2024. All rights reserved.