Why is broadcasting-based KNN in PyTorch so slow?

Problem description (votes: 0, answers: 1)

I am trying to find the k nearest neighbors of grid points. Here is the code that generates the grid:

import torch

def grid_by(lims=[[0, 1], [0, 1]], size=[28, 28]):
    """
    Creates a tensor of 2D grid points.
    Grid points have one-to-one correspondence with input pixel values
    that are flattened in row-major order.

    Args:
        lims: [[domain of H], [domain of W]]
        size: [H, W]
    Returns:
        grid: Tensor of shape [(H*W), 2]
    """
    assert len(size) == 2 and len(lims) == len(size)
    # The H axis (i == 0) is generated in reverse so that row 0 sits at the
    # top of the domain.
    expansions = [torch.linspace(start, end, steps) if i != 0
                  else torch.linspace(end, start, steps)
                  for i, ((start, end), steps) in enumerate(zip(lims, size))]
    # Swap the columns so each row is ordered (W, H).
    grid = torch.index_select(torch.cartesian_prod(*expansions),
                              dim=1,
                              index=torch.tensor([1, 0]))
    return grid

I wrote a custom KNN function so that it can run on the GPU in PyTorch. Assuming L2 distance, the PyTorch code looks like this:

def knn(grid, k):
    """
    Brute Force KNN.

    Args:
        grid: Tensor of shape [(H*W), D]
        k: Int representing number of neighbors
    """
    d = grid.shape[-1]
    Xr = grid.unsqueeze(1)    # [(H*W), 1, D]
    Yr = grid.view(1, -1, d)  # [1, (H*W), D]
    distances = torch.sqrt(torch.sum((Xr - Yr)**2, -1))  # [(H*W), (H*W)]
    dist, index = distances.topk(k, largest=False, dim=-1)
    return dist, index

grid = grid_by().to('cuda')
knn_dist, knn_index = knn(grid, k=10)

The sklearn code looks like this:

from sklearn.neighbors import NearestNeighbors

grid = grid_by()
nn = NearestNeighbors(n_jobs=-1)
nn.fit(grid)
knn_dist, knn_index = nn.kneighbors(grid, n_neighbors=10)

Since I used broadcasting, I expected it to run reasonably fast on the GPU. However, when I measured the runtime with `timeit`, it was much slower than sklearn's NearestNeighbors running on the CPU. What is the reason it is so slow?

Tags: scikit-learn, pytorch, knn, array-broadcasting
1 Answer

I cannot reproduce this. I ran the following `timeit` tests on a machine with 64 CPU cores and a 3090 GPU.

knn      cpu                   786 µs ± 74 µs per loop
knn      cuda:0                197 µs ± 437 ns per loop
sklearn  cpu     n_jobs=-1     21.6 ms ± 212 µs per loop
sklearn  cpu     n_jobs=None   1.16 ms ± 299 ns per loop
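One thing worth checking when timing the GPU path: CUDA kernel launches are asynchronous, so a bare `timeit` call can return before the kernels have finished and report misleading numbers. A minimal timing sketch (the grid size and device fallback are assumptions, not from the original post):

```python
import timeit

import torch

def knn(grid, k):
    # Same brute-force KNN as in the question.
    d = grid.shape[-1]
    Xr = grid.unsqueeze(1)    # [(H*W), 1, D]
    Yr = grid.view(1, -1, d)  # [1, (H*W), D]
    distances = torch.sqrt(torch.sum((Xr - Yr) ** 2, -1))
    return distances.topk(k, largest=False, dim=-1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
grid = torch.rand(28 * 28, 2, device=device)

def timed_knn():
    knn(grid, k=10)
    if device == 'cuda':
        torch.cuda.synchronize()  # wait until the kernels actually finish

timed_knn()  # warm-up: the first call pays one-time setup costs
print(f"{timeit.timeit(timed_knn, number=100) / 100 * 1e6:.1f} µs per loop")
```

Without the synchronize inside the timed function, `timeit` mostly measures kernel-launch overhead rather than the computation itself.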
