我正在 PyTorch 中工作,我试图找到每个唯一行第一次出现的索引。这是我到目前为止所拥有的:
import torch
import numpy as np
# Example tensor
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# Finding unique rows
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)
# Finding the first occurrence index of each unique row
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]
此代码可以工作,但有一个低效的“unique_indices”循环。有没有更高效、类似 Pytorch 的方法来获取 unique_indices?
避免循环的一种方法是使用 2D 张量。我将使用值“1000”作为任何大于原始数据张量中行数的值的占位符,即
1000 > 100
。
因此,我将创建一个 2D 张量
A
,其维度是原始行数除以唯一行数。我将其初始化为 1000
,这意味着“未定义的行索引”。然后,对于每个原始行索引 i
,我将设置 A[i] = inverse_indices[i]
。完成后,ith
的A
行将主要是1000
,除了索引是行i
现在所在的唯一行的索引的列。
现在,让我们按列考虑这个张量
A
。 jth
列是 jth
唯一行。它的值同样主要是 1000
,除了一些包含 j
的行。该列的 argmin
正是映射到唯一行 j
的第一个原始行的索引。
这是基于您自己的代码的代码:
import torch
import numpy as np
# Example tensor
data = torch.rand(100, 5)
data[np.random.choice(100, 50, replace=False)] = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# Finding unique rows
u_data, inverse_indices, counts = torch.unique(data, dim=0, return_inverse=True, return_counts=True)
# Finding the first occurrence index of each unique row
unique_indices = torch.zeros(len(u_data), dtype=torch.long)
for idx in range(len(u_data)):
unique_indices[idx] = torch.where(inverse_indices == idx)[0][0]
# My suggestion:
A = 1000 * torch.ones((len(data), len(u_data)), dtype=torch.long)
A[torch.arange(len(data)), inverse_indices] = inverse_indices
unique_indices2 = torch.argmin(A, dim=0)
# Verify my method matches yours
print(torch.allclose(unique_indices2,unique_indices))