带梯度的 Pytorch bincount

问题描述 投票:0回答:3

我正在尝试使用 bincount 从数组的某些索引的总和中获取梯度。但是,pytorch 没有实现渐变。这可以通过循环和 torch.sum 来实现,但它太慢了。是否有可能在 pytorch 中有效地执行此操作(可能是 einsum 或 index_add)?当然,我们可以遍历索引并逐个添加,但这会显着增加计算图的大小并且性能非常低。

import torch
from torch import autograd
import numpy as np
tt = lambda x, grad=True: torch.tensor(x, requires_grad=grad)    
inds = tt([1, 5, 7, 1], False).long()
y = tt(np.arange(4) + 0.1).float()
sum_y_section = torch.bincount(inds, y * y, minlength=8)
#sum_y_section = torch.sum(y * y)
grad = autograd.grad(sum_y_section, y, create_graph=True, allow_unused=False)
print("sum_y_section", sum_y_section)
print("grad", grad)
pytorch autograd
3个回答
3
投票

我们可以使用 Pytorch V1.11 中的一个新功能,称为 scatter_reduce。

bincount = lambda inds, arr: torch.scatter_reduce(arr, 0, inds, reduce="sum")

1
投票

我会尝试使用钩子以自定义方式操纵渐变


0
投票

torch.scatter_reduce
在 Pytorch 1.13 中有一个
src
位置参数。这个简单的例子演示了
bincount
scatter_reduce
之间的等价性:

num_bins = 8
bins = torch.zeros(num_bins)

#generate indices and weights
num_indices = 100
indices = torch.randint(num_bins, size=(num_indices,))
weights = torch.rand(num_indices)

# Counting Indices

# with torch.bincount
counts1 = torch.bincount(indices, minlength=num_bins)

# with torch.scatter_reduce
counts2 = bins.scatter_reduce(0, indices, torch.ones(num_indices), reduce = 'sum')
print(counts1)
print(counts2)

# Binning Weights

# with torch.bincount
binned_wts1 =  indices.bincount(weights, minlength=num_bins)

# with torch.scatter_reduce
binned_wts2 = bins.scatter_reduce(0, indices, weights, reduce='sum')

print(binned_wts1)
print(binned_wts2)

© www.soinside.com 2019 - 2024. All rights reserved.