在 PyTorch 张量中查找唯一行索引的有效方法

问题描述 投票:0回答:1

我正在 PyTorch 中工作,我试图找到每个唯一行第一次出现的索引。这是我到目前为止所拥有的:

import torch
import numpy as np

# Example tensor
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# Finding unique rows
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)

# Finding the first occurrence index of each unique row
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
    unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]

此代码可以工作,但有一个低效的“unique_indices”循环。有没有更高效、类似 Pytorch 的方法来获取 unique_indices?

python machine-learning pytorch tensor
1个回答
0
投票

避免循环的一种方法是使用 2D 张量。我将使用值“1000”作为任何大于原始数据张量中行数的值的占位符,即

1000 > 100

因此,我将创建一个 2D 张量

A
,其维度是原始行数除以唯一行数。我将其初始化为
1000
,这意味着“未定义的行索引”。然后,对于每个原始行索引
i
,我将设置
A[i] = inverse_indices[i]
。完成后,
ith
A
行将主要是
1000
,除了索引是行
i
现在所在的唯一行的索引的列。

现在,让我们按列考虑这个张量

A
jth
列是
jth
唯一行。它的值同样主要是
1000
,除了一些包含
j
的行。该列的
argmin
正是映射到唯一行
j
的第一个原始行的索引。

这是基于您自己的代码的代码:

import torch
import numpy as np

# Example tensor
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# Finding unique rows
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)

# Finding the first occurrence index of each unique row
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
    unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]

# My suggestion:

A = 1000 * torch.ones((len(data), len(u_data)), dtype=torch.long)
A[torch.arange(len(data)), inverse_indices] = inverse_indices
unique_indices2 = torch.argmin(A, dim=0)

# Verify my method matches yours
print(torch.allclose(unique_indices2,unique_indices))
© www.soinside.com 2019 - 2024. All rights reserved.