我有一些来自 PyTorch Geometric 的
Data
对象,其形式为 Data(x=[621, 2], edge_index=[2, 1142], edge_attr=[1142, 1])
。
我想拆分数据以应用某些模型。我试过这个代码
import torch_geometric.transforms as T
split = T.RandomNodeSplit(split='random', num_val=0.1, num_test=0.2)
data = split(data)
data
但输出仍然是与之前相同的 Data 对象,没有
train_mask
、val_mask
、test_mask
。现在有人知道为什么吗?谢谢:)
我见过一些将此方法应用于 PyTorch Geometric 的自定义数据集(如 Planetoid)的示例,但我的数据不是其中之一。
通过文档,我在代码中没有发现任何错误。都没有找到任何其他例子。
您需要设置关键参数。默认值为 y,但如果您没有标签,可以将其设置为 key=None
随机节点分割(键=无)