Is there a way to compute histograms of a torch tensor in batches?
For example, x is a tensor of shape (64, 224, 224):
# x will have shape of (64, 256)
x = batch_histogram(x, bins=256, min=0, max=255)
You can do this in a single line of code with torch.nn.functional.one_hot:
torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
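The rationale is that one_hot respects batches: for each value v in the last dimension of the given tensor, it creates a tensor filled with 0s except for the v-th component, which is 1. Summing all of these one-hot encodings along the second-to-last dimension (which was the last dimension of data_tensor) gives the number of times each value v appears in each row of the data.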
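To make the two steps concrete, here is a small sketch that separates the one_hot call from the sum. The input values are the same ones used in the tests of this answer; the variable names encoded and hist are only illustrative.

```python
import torch
import torch.nn.functional as F

# Two rows of integer class values; torch.tensor infers dtype torch.long.
data = torch.tensor([[2, 5, 1, 1],
                     [3, 0, 3, 1]])

# Step 1: every value becomes a length-6 indicator vector -> shape (2, 4, 6).
encoded = F.one_hot(data, num_classes=6)

# Step 2: summing over the original last dimension (now dim=-2) counts
# how often each class occurs in each row -> shape (2, 6).
hist = encoded.sum(dim=-2)

print(encoded.shape)  # torch.Size([2, 4, 6])
print(hist)
# tensor([[0, 2, 1, 0, 0, 1],
#         [1, 1, 0, 2, 0, 0]])
```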
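A potentially serious drawback of this approach is memory usage, since every value is expanded into a tensor of size num_classes (so the intermediate result is the size of data_tensor multiplied by num_classes). However, this memory usage is temporary, because sum collapses the extra dimension again, and the result will typically be smaller than data_tensor. I say "typically" because if num_classes is much larger than the size of the last dimension of data_tensor, the result will be correspondingly larger.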
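To get a feel for the size of that temporary tensor: one_hot returns int64 values, so the intermediate takes numel * num_classes * 8 bytes. For the tensor in the question, flattened to (64, 224 * 224) with 256 classes, that is roughly 6.6 GB. One way to bound the peak is to process the first dimension in chunks. The sketch below is only an illustration of that idea, not part of the original answer: one_hot_intermediate_bytes and batch_histogram_chunked are hypothetical helper names, and the chunked version assumes the input has at least two dimensions.

```python
import torch
import torch.nn.functional as F


def one_hot_intermediate_bytes(data_tensor, num_classes):
    """Bytes used by the temporary one-hot tensor (int64, 8 bytes per element)."""
    return data_tensor.numel() * num_classes * 8


def batch_histogram_chunked(data_tensor, num_classes, chunk_size=8):
    """Hypothetical helper: same result as the one-liner, but only
    chunk_size slices along dim 0 are one-hot encoded at a time."""
    if data_tensor.dim() < 2:
        raise ValueError("expected at least 2 dimensions (batch, ..., values)")
    chunks = data_tensor.split(chunk_size, dim=0)
    return torch.cat([F.one_hot(c, num_classes).sum(dim=-2) for c in chunks], dim=0)


x = torch.randint(0, 256, (16, 512))
print(one_hot_intermediate_bytes(x, 256))  # 16 * 512 * 256 * 8 = 16777216
hist = batch_histogram_chunked(x, 256, chunk_size=4)
print(hist.shape)  # torch.Size([16, 256])
```

With chunk_size=4 the temporary tensor is a quarter of the full size in this example; the final histogram is identical because each row is counted independently.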
Here is the code with documentation, followed by pytest tests:
import pytest
import torch


def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, data_tensor.max() + 1 is used (an error is thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of data_tensor,
        that is, result[d_1,...,d_{n-1}, c] = number of times c appears in data_tensor[d_1,...,d_{n-1}].
    """
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)


def test_batch_histogram():
    data = [2, 5, 1, 1]
    expected = [0, 2, 1, 0, 0, 1]
    run_test(data, expected)

    data = [
        [2, 5, 1, 1],
        [3, 0, 3, 1],
    ]
    expected = [
        [0, 2, 1, 0, 0, 1],
        [1, 1, 0, 2, 0, 0],
    ]
    run_test(data, expected)

    data = [
        [[2, 5, 1, 1], [2, 4, 1, 1], ],
        [[3, 0, 3, 1], [2, 3, 1, 1], ],
    ]
    expected = [
        [[0, 2, 1, 0, 0, 1], [0, 2, 1, 0, 1, 0], ],
        [[1, 1, 0, 2, 0, 0], [0, 2, 1, 1, 0, 0], ],
    ]
    run_test(data, expected)


def test_empty_data():
    data = []
    num_classes = 2
    expected = [0, 0]
    run_test(data, expected, num_classes)

    data = [[], []]
    num_classes = 2
    expected = [[0, 0], [0, 0]]
    run_test(data, expected, num_classes)

    data = [[], []]
    run_test(data, expected=None, exception=RuntimeError)  # num_classes not provided for empty data


def run_test(data, expected, num_classes=-1, exception=None):
    data_tensor = torch.tensor(data, dtype=torch.long)
    if exception is None:
        expected_tensor = torch.tensor(expected, dtype=torch.long)
        actual = batch_histogram(data_tensor, num_classes)
        assert torch.equal(actual, expected_tensor)
    else:
        with pytest.raises(exception):
            batch_histogram(data_tensor, num_classes)
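I'm not sure, but this looks like a hard thing to do, and PyTorch has nothing out of the box for it.
A histogram is a statistical operation. It is inherently discrete and non-differentiable. Histograms are also not inherently vectorizable. So I don't think there is a simpler way than a plain loop-based solution.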
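X = torch.rand(64, 224, 224) * 255  # scale to [0, 255) so the values span the histogram range
# torch.histc flattens its input, so each x yields one 1-D histogram with 256 bins.
# torch.stack (rather than torch.cat) keeps one row per sample, giving shape (64, 256).
h = torch.stack([torch.histc(x, bins=256, min=0, max=255) for x in X], 0)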
If anyone has a better solution, feel free to comment.
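As suggested in PyTorch issue #99719, you can do this with torch.Tensor.scatter_add_. scatter_add_ is more efficient than torch.nn.functional.one_hot because it never materializes the one-hot intermediate.
Similar to @user118967's answer: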
# https://github.com/pytorch/pytorch/issues/99719#issuecomment-1664135524
import torch


def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, data_tensor.max() + 1 is used (an error is thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of data_tensor,
        that is, result[d_1,...,d_{n-1}, c] = number of times c appears in data_tensor[d_1,...,d_{n-1}].
    """
    # Infer the number of classes from the data when it is not given.
    nc = int(data_tensor.max()) + 1 if num_classes <= 0 else num_classes
    # One zero-initialized histogram per row, on the same device as the input.
    hist = torch.zeros((*data_tensor.shape[:-1], nc), dtype=data_tensor.dtype, device=data_tensor.device)
    # A scalar 1 viewed at the input shape; expand does not allocate extra memory.
    ones = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand(data_tensor.shape)
    # Add 1 to hist[..., v] for every value v along the last dimension.
    hist.scatter_add_(-1, data_tensor.long(), ones)
    return hist
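As a usage sketch for the shape from the question, the snippet below applies the scatter_add_ approach to an image-like batch. It assumes the pixel values are already integers in [0, 255] (generated here with torch.randint) and that the histogram should cover all pixels of each sample, so the two spatial dimensions are flattened first. The function is restated under the hypothetical name batch_histogram_scatter only to keep the snippet self-contained.

```python
import torch


def batch_histogram_scatter(data_tensor, num_classes):
    """Hypothetical restatement of the scatter_add_ approach for int64 input."""
    hist = torch.zeros((*data_tensor.shape[:-1], num_classes),
                       dtype=torch.long, device=data_tensor.device)
    # Scalar 1 viewed at the input shape; no extra memory is allocated.
    ones = torch.ones((), dtype=torch.long, device=data_tensor.device).expand(data_tensor.shape)
    hist.scatter_add_(-1, data_tensor, ones)
    return hist


# Assumed input: 64 single-channel images with integer values in [0, 255].
x = torch.randint(0, 256, (64, 224, 224))

# Treat every image as one row of 224 * 224 values.
flat = x.flatten(start_dim=1)  # shape (64, 50176)

hist = batch_histogram_scatter(flat, num_classes=256)
print(hist.shape)  # torch.Size([64, 256])
```

Each row of hist sums to 224 * 224, since every pixel lands in exactly one bin. Floating-point images would need to be quantized to integer bin indices before being passed in.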