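I am developing a batchable PyTorch utility, free of loops and recursion, called
concat_aggregate
, which groups the rows of an input tensor x
according to the labels given by the tensor index
. It should pad the groups so that the resulting tensor is rectangular. For example,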
x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
concat_aggregate(x, index)
should output:
torch.tensor([
[[0, 0], [0, 0], [0, 0], [0, 0]],
[[7, 70], [8, 80], [9, 90], [0, 0]],
[[10, 100], [0, 0], [0, 0], [0, 0]],
[[5, 50], [6, 60], [11, 110], [12, 120]]
])
I hacked together this function:
import torch

def cat_aggregate(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Number of groups and the number of features in each row of x
    num_groups = index.max().item() + 1
    num_features = x.size(1)
    # Count how many rows fall into each group
    group_sizes = torch.zeros(num_groups, dtype=torch.long, device=x.device)
    group_sizes.index_add_(0, index, torch.ones_like(index, dtype=torch.long))
    # Prepare the output tensor, padded with zeros up to the largest group size
    max_num_elements = group_sizes.max().item()
    result = torch.zeros(num_groups, max_num_elements, num_features, dtype=x.dtype, device=x.device)
    # Positions to fill in the result tensor
    positions = torch.zeros_like(group_sizes)  # Current fill position in each group
    # Fill the tensor row by row
    for i in range(x.size(0)):
        group_id = index[i]
        result[group_id, positions[group_id]] = x[i]
        positions[group_id] += 1
    return result
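It returns the correct results for one-dimensional and two-dimensional tensors. However, it has to iterate over
x.size(0)
in a Python loop, which makes it at least linear in the length of x
. I am not sure that what I have is idiomatic. Does anyone here see possible efficiency/complexity improvements, or an obvious way to extend it to two-dimensional tensors? I am surprised that the PyTorch API lacks a function like this.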
This should be an equivalent function that does not use a for loop:
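import torch

def cat_aggregate(x, index):
    # Number of rows carrying each label
    index_count = torch.bincount(index)
    # Number of zero rows each group needs to reach the largest group size
    fill_count = index_count.max() - index_count
    fill_zeros = torch.zeros_like(x[0]).repeat(int(fill_count.sum()), 1)
    # Group label for each padding row
    fill_index = torch.arange(fill_count.shape[0], device=index.device).repeat_interleave(fill_count)
    index_ = torch.cat([index, fill_index], dim=0)
    x_ = torch.cat([x, fill_zeros], dim=0)
    # A stable sort keeps the original row order within each group, padding last
    order = torch.argsort(index_, stable=True)
    x_ = x_[order].view(index_count.shape[0], int(index_count.max()), -1)
    return x_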
Output:
tensor([[[ 0, 0],
[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[ 7, 70],
[ 8, 80],
[ 9, 90],
[ 0, 0]],
[[ 10, 100],
[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[ 5, 50],
[ 6, 60],
[ 11, 110],
[ 12, 120]]])
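Since the question also asks about other tensor shapes, here is a minimal sketch of a different loop-free route: instead of appending padding rows and sorting, it computes each row's slot inside its group and writes the rows into a preallocated zero tensor. The name concat_aggregate_scatter is hypothetical. The sketch assumes index is a non-empty 1-D tensor of non-negative integer labels, x has shape (N, ...) with any trailing dimensions, and the installed PyTorch supports the stable argument of torch.argsort.

```python
import torch

def concat_aggregate_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of the PyTorch API.
    counts = torch.bincount(index)            # rows per group, shape (G,)
    num_groups = counts.numel()
    max_count = int(counts.max())             # size of the largest group

    # Stable sort so rows keep their original order inside each group.
    order = torch.argsort(index, stable=True)
    sorted_index = index[order]

    # Position of each group's first row in the sorted order (exclusive cumsum).
    starts = torch.cumsum(counts, dim=0) - counts
    # Slot of every sorted row within its own group: 0, 1, 2, ...
    slots = torch.arange(index.numel(), device=index.device) - starts[sorted_index]

    # Zero-padded output; trailing dimensions of x are carried over unchanged.
    out = x.new_zeros((num_groups, max_count) + tuple(x.shape[1:]))
    out[sorted_index, slots] = x[order]
    return out

x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80],
                  [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
print(concat_aggregate_scatter(x, index))
```

On the example from the question this should print the same padded tensor as above. Because the output shape is built from x.shape[1:], the same function also accepts a 1-D x or an x with more than two dimensions. Each (group, slot) pair is written at most once, so the indexed assignment has no write conflicts.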