How do I compute histograms of PyTorch tensors in batches?

Question  Votes: 0  Answers: 3

Is there a way to compute histograms of a torch tensor in batches?

For example, x is a tensor of shape (64, 224, 224):
# x will have shape of (64, 256)
x = batch_histogram(x, bins=256, min=0, max=255)
python pytorch histogram tensor
3 Answers
2 votes

You can do this in a single line with torch.nn.functional.one_hot:

torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)

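The rationale is that one_hot respects batches: for each value v in the last dimension of the given tensor, it creates a vector filled with 0s except for the v-th component, which is 1. Summing all of these one-hot encodings over the penultimate dimension (which was the last dimension of data_tensor) gives the number of times each value v appears in each row of data.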
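To make the mechanism concrete, here is a minimal sketch on a single row of values. The variable names (row, encoded, counts) are illustrative and do not come from the answer.

```python
import torch
import torch.nn.functional as F

row = torch.tensor([2, 5, 1, 1])          # values in the last dimension
encoded = F.one_hot(row, num_classes=6)   # shape (4, 6): one one-hot vector per value
counts = encoded.sum(dim=-2)              # shape (6,): collapse the axis that indexed the values

print(encoded.shape)  # torch.Size([4, 6])
print(counts)         # tensor([0, 2, 1, 0, 0, 1])
```

The value 1 appears twice and the values 2 and 5 appear once each, which is exactly what the summed encoding reports.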

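A potentially serious drawback of this method is memory usage, because every value is expanded into a vector of size num_classes (so the intermediate tensor has the size of data_tensor multiplied by num_classes). This memory usage is temporary, however, because sum collapses the extra dimension again, and the result is usually smaller than data_tensor. I say "usually" because if num_classes is much larger than the size of the last dimension of data_tensor, the result will be correspondingly larger.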
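To put a number on that cost: for the tensor in the question, flattened to shape (64, 224 * 224) with 256 classes, the intermediate int64 one-hot tensor would hold 64 * 224 * 224 * 256 values of 8 bytes each, which is about 6.1 GiB. One possible way to cap the peak, sketched below, is to process the last dimension in chunks and accumulate the partial counts. The helper name batch_histogram_chunked and the chunk_size parameter are assumptions made for this illustration and are not part of the answer.

```python
import torch
import torch.nn.functional as F


def batch_histogram_chunked(data_tensor, num_classes, chunk_size=4096):
    """Hypothetical helper: same counts as one_hot(...).sum(dim=-2), lower peak memory."""
    result = torch.zeros(
        *data_tensor.shape[:-1], num_classes,
        dtype=torch.long, device=data_tensor.device,
    )
    # Only one chunk of the last dimension is one-hot encoded at a time.
    for chunk in torch.split(data_tensor, chunk_size, dim=-1):
        result += F.one_hot(chunk, num_classes).sum(dim=-2)
    return result


x = torch.randint(0, 256, (4, 1000))
chunked = batch_histogram_chunked(x, num_classes=256, chunk_size=128)
direct = F.one_hot(x, 256).sum(dim=-2)
print(torch.equal(chunked, direct))  # True
```

The trade-off is a Python-level loop over chunks in exchange for an intermediate tensor whose size depends on chunk_size rather than on the full last dimension.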

Here is the documented code, followed by pytest tests:

import pytest
import torch


def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, data_tensor.max() + 1 is used (an error is thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of data_tensor,
        that is, result[d_1,...,d_{n-1}, c] = number of times c appears in data_tensor[d_1,...,d_{n-1}].
    """
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)


def test_batch_histogram():
    data = [2, 5, 1, 1]
    expected = [0, 2, 1, 0, 0, 1]
    run_test(data, expected)

    data = [
        [2, 5, 1, 1],
        [3, 0, 3, 1],
    ]
    expected = [
        [0, 2, 1, 0, 0, 1],
        [1, 1, 0, 2, 0, 0],
    ]
    run_test(data, expected)

    data = [
        [[2, 5, 1, 1], [2, 4, 1, 1], ],
        [[3, 0, 3, 1], [2, 3, 1, 1], ],
    ]
    expected = [
        [[0, 2, 1, 0, 0, 1], [0, 2, 1, 0, 1, 0], ],
        [[1, 1, 0, 2, 0, 0], [0, 2, 1, 1, 0, 0], ],
    ]
    run_test(data, expected)


def test_empty_data():
    data = []
    num_classes = 2
    expected = [0, 0]
    run_test(data, expected, num_classes)

    data = [[], []]
    num_classes = 2
    expected = [[0, 0], [0, 0]]
    run_test(data, expected, num_classes)

    data = [[], []]
    run_test(data, expected=None, exception=RuntimeError)  # num_classes not provided for empty data


def run_test(data, expected, num_classes=-1, exception=None):
    data_tensor = torch.tensor(data, dtype=torch.long)

    if exception is None:
        expected_tensor = torch.tensor(expected, dtype=torch.long)
        actual = batch_histogram(data_tensor, num_classes)
        assert torch.equal(actual, expected_tensor)
    else:
        with pytest.raises(exception):
            batch_histogram(data_tensor, num_classes)

0 votes

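I'm not sure, but this looks hard to do, and PyTorch does not seem to offer anything for it out of the box.

A histogram is a statistical operation. It is inherently discrete and non-differentiable. Histograms are also not inherently vectorizable. So I don't think there is anything simpler than a plain loop-based solution.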

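import torch
X = torch.rand(64, 224, 224) * 255  # scaled so the values span [0, 255)
# torch.stack yields shape (64, 256); torch.cat along dim 0 would give a flat (16384,) vector
h = torch.stack([torch.histc(x, bins=256, min=0, max=255) for x in X], 0)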

If anyone has a better solution, feel free to leave a comment.


0 votes

As suggested in PyTorch issue #99719, you can do this with torch.Tensor.scatter_add_. scatter_add_ is more efficient than torch.nn.functional.one_hot because it never materializes the intermediate one-hot tensor.

Similar to @user118967's answer:

import torch


# https://github.com/pytorch/pytorch/issues/99719#issuecomment-1664135524
def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, tensor.max() + 1 is used (an error is thrown if tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of tensor,
        that is, result[d_1,...,d_{n-1}, c] = number of times c appears in tensor[d_1,...,d_{n-1}].
    """
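    # Convert to a Python int so it can be used as a dimension size.
    nc = (int(data_tensor.max()) + 1) if num_classes <= 0 else num_classes
    # Allocate on the input's device so the function also works for CUDA tensors.
    hist = torch.zeros((*data_tensor.shape[:-1], nc), dtype=data_tensor.dtype, device=data_tensor.device)
    ones = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand(data_tensor.shape)
    hist.scatter_add_(-1, data_tensor.long(), ones)
    return hist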

Tested with the test cases above in Google Colab.
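To connect this back to the question, the sketch below applies the same scatter_add_ idea to a batch of shape (64, 224, 224) by flattening each image first. It assumes the pixels are integers in [0, 256), so each value can be used directly as its bin index; the wrapper name image_batch_histogram is hypothetical. The cross-check uses a per-image torch.bincount loop, which is a different technique used here only as a reference.

```python
import torch


def image_batch_histogram(images, bins=256):
    """Hypothetical wrapper: per-image counts of integer pixel values in [0, bins)."""
    flat = images.reshape(images.shape[0], -1).long()   # (N, H * W)
    hist = torch.zeros(flat.shape[0], bins, dtype=torch.long, device=flat.device)
    hist.scatter_add_(-1, flat, torch.ones_like(flat))  # add 1 at each pixel's bin
    return hist


x = torch.randint(0, 256, (64, 224, 224))
h = image_batch_histogram(x)
print(h.shape)  # torch.Size([64, 256])

# Reference: one torch.bincount call per image.
reference = torch.stack([torch.bincount(img.flatten(), minlength=256) for img in x])
print(torch.equal(h, reference))  # True
```

For floating-point inputs or bin widths other than 1, the values would first need to be mapped to integer bin indices, which this sketch does not attempt.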
