Grouping PyTorch feature tensors by label via concatenation


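I'm developing a batchable, loop- and recursion-free PyTorch utility, concat_aggregate, that groups the rows of an input tensor x according to the labels given by a tensor index. It should pad the groups with zero rows so that the resulting tensor is rectangular. For example,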

x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
concat_aggregate(x, index)

should output:

torch.tensor([
    [[0, 0], [0, 0], [0, 0], [0, 0]],
    [[7, 70], [8, 80], [9, 90], [0, 0]],
    [[10, 100], [0, 0], [0, 0], [0, 0]],
    [[5, 50], [6, 60], [11, 110], [12, 120]]
])

I hacked together this function:

import torch

def cat_aggregate(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Number of groups and the number of features in each row of x
    num_groups = index.max().item() + 1
    num_features = x.size(1)
    # Count the number of rows in each group
    group_sizes = torch.zeros(num_groups, dtype=torch.long, device=x.device)
    group_sizes.index_add_(0, index, torch.ones_like(index, dtype=torch.long))
    # Prepare the output tensor, padded with zeros and sized for the largest group
    max_num_elements = group_sizes.max().item()
    result = torch.zeros(num_groups, max_num_elements, num_features, dtype=x.dtype, device=x.device)
    # Positions to fill in the result tensor
    positions = torch.zeros_like(group_sizes)  # Current fill position in each group
    # Fill the tensor
    for i in range(x.size(0)):
        group_id = index[i]
        result[group_id, positions[group_id]] = x[i]
        positions[group_id] += 1
    return result

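This returns the correct result for one- and two-dimensional tensors. However, it has to iterate over x.size(0), which makes it at least linear in the length of x. I'm not sure whether what I have is idiomatic. Does anyone here see possible efficiency/complexity improvements, or an obvious way to extend it to higher-dimensional tensors? I'm surprised that such a function is missing from the PyTorch API.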

pytorch tensor pytorch-geometric
1 answer

This should be equivalent to your function, without the for loop:

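import torch

def cat_aggregate(x, index):
    # Rows per label; labels that never occur get a count of 0
    index_count = torch.bincount(index)
    max_count = int(index_count.max())
    # Number of zero rows each group needs to reach max_count
    fill_count = max_count - index_count
    fill_zeros = x.new_zeros((int(fill_count.sum()), *x.shape[1:]))
    # torch.arange keeps the labels integral (torch.range is deprecated and returns floats)
    fill_index = torch.arange(index_count.shape[0], device=index.device).repeat_interleave(fill_count)
    index_ = torch.cat([index, fill_index], dim=0)
    x_ = torch.cat([x, fill_zeros], dim=0)
    # A stable sort keeps the original row order within each group and puts the padding last
    order = torch.sort(index_, stable=True).indices
    return x_[order].view(index_count.shape[0], max_count, *x.shape[1:])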

Output:

tensor([[[  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  7,  70],
         [  8,  80],
         [  9,  90],
         [  0,   0]],

        [[ 10, 100],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  5,  50],
         [  6,  60],
         [ 11, 110],
         [ 12, 120]]])
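
As a possible variation, here is a minimal sketch that skips building the padding rows and instead computes each row's slot inside its group, then writes the rows into a zero tensor with one indexed assignment. The name cat_aggregate_scatter is hypothetical, and the sketch assumes index is a 1-D tensor of non-negative integer labels with one label per row of x, plus a PyTorch version whose torch.sort accepts stable=True. Trailing dimensions of x are carried through unchanged, so it should also cover inputs with more than two dimensions.

```python
import torch


def cat_aggregate_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Group rows of x by label and zero-pad each group to the largest group size.

    Hypothetical helper; assumes index is 1-D, non-negative, integer-valued,
    and index.numel() == x.size(0) > 0.
    """
    counts = torch.bincount(index)  # rows per label (0 for unused labels)
    num_groups, max_count = counts.numel(), int(counts.max())

    # Stable sort keeps rows that share a label in their original order.
    sorted_index, order = torch.sort(index, stable=True)

    # Offset of each group's first row in the sorted layout.
    starts = torch.cumsum(counts, dim=0) - counts
    # Position of every sorted row inside its own group: 0, 1, 2, ...
    pos = torch.arange(index.numel(), device=index.device) - starts[sorted_index]

    # Write each row to its (group, position) slot; untouched slots stay zero.
    out = x.new_zeros((num_groups, max_count, *x.shape[1:]))
    out[sorted_index, pos] = x[order]
    return out


x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80],
                  [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
result = cat_aggregate_scatter(x, index)
```

For the x and index from the question, result should match the expected tensor shown there. Note that the sort costs O(n log n), so this is not asymptotically cheaper than the original loop, which is O(n); the gain is that the per-row work runs inside PyTorch kernels rather than in a Python loop.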