在python中高效查找邻居

问题描述 投票:0回答:1

我有一个三维网格。对于网格上的每个点,我想找到它最近的邻居。由于我的网格是均匀采样的,我只想收集最接近的邻居。

示例网格:

所需邻居:

对于给定点 X,我需要以下邻居:

我目前的工作代码:

import numpy as np
import cProfile

class Neighbours:
    # Get neighbors
    @classmethod
    def get_neighbour_indices(cls, row, col, frame, distance=1):                
        # Define the indices for the neighbor pixels    
        r = np.linspace(row - distance, row + distance, 2 * distance + 1)
        c = np.linspace(col - distance, col + distance, 2 * distance + 1)
        f = np.linspace(frame - distance, frame + distance, 2 * distance + 1)
        nc, nr, nf = np.meshgrid(c, r, f)
        neighbors = np.vstack((nr.flatten(), nc.flatten(), nf.flatten())).T
    
        # Filter out valid neighbor indices within the array bounds
        valid_indices = (neighbors[ :, 0] >= 0) & (neighbors[ :, 0] < nRows) & (neighbors[ :, 1] >= 0) & (neighbors[ :, 1] < nCols) & (neighbors[ :, 2] >= 0) & (neighbors[ :, 2] < nFrames) 

        # Return the valid neighbor indices
        valid_neighbors = neighbors[valid_indices]
        return valid_neighbors

    @classmethod
    def MapIndexVsNeighbours(cls):
        neighbours_info = np.empty((nRows * nCols * nFrames), dtype=object)
        for frame in range(nFrames):
            for row in range(nRows):
                for col in range(nCols):        
                    neighbour_indices = cls.get_neighbour_indices(row, col, frame, distance=1)        
                    flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                    neighbours_info[flat_idx] = neighbour_indices                            

        return neighbours_info


########################------------------main()-------##################
####--run
if __name__ == "__main__":  
    nRows = 151
    nCols = 151
    nFrames = 24
    cProfile.run('Neighbours.MapIndexVsNeighbours()', sort='cumulative')  
    print() 

问题:对于较大的网格(例如201 x 201 x 24),程序需要很长时间。在使用

cProfile
的分析结果中,我可以看到
meshgrid()
中的
get_neighbour_indices()
花费了相当长的时间。总而言之,这不是一个有效的实现。此外,我尝试在单独的线程上执行
MapIndexVsNeighbours()
,但由于GIL锁,它并没有真正并行执行。因此,可以并行执行的东西将是理想的实现。

python performance parallel-processing cpython nearest-neighbor
1个回答
0
投票

您可以使用 来加速计算:

from timeit import timeit

from numba import njit


@njit
def get_arr(x, y, f, w, h, distance, frames):
    _x_from = max(0, x - distance)
    _x_to = min(w - 1, x + distance)

    _y_from = max(0, y - distance)
    _y_to = min(h - 1, y + distance)

    _z_from = max(0, f - distance)
    _z_to = min(frames - 1, f + distance)

    out = []
    for _x in range(_x_from, _x_to + 1):
        for _y in range(_y_from, _y_to + 1):
            for _z in range(_z_from, _z_to + 1):
                out.append([_x, _y, _z])
    return np.array(out)


def MapIndexVsNeighbours_numba(nRows, nCols, nFrames):
    neighbours_info = np.empty((nRows * nCols * nFrames), dtype=object)
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                neighbour_indices = get_arr(row, col, frame, nRows, nCols, 1, nFrames)
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                neighbours_info[flat_idx] = neighbour_indices
    return neighbours_info


nRows = 151
nCols = 151
nFrames = 24

v1 = MapIndexVsNeighbours_numba(nRows, nCols, nFrames)
v2 = Neighbours.MapIndexVsNeighbours()
assert all(np.allclose(a, b) for a, b in zip(v1, v2))

t_numba = timeit(
    "MapIndexVsNeighbours_numba(nRows, nCols, nFrames)", number=1, globals=globals()
)
t_original = timeit("Neighbours.MapIndexVsNeighbours()", number=1, globals=globals())

print(f"{t_numba=}")
print(f"{t_original=}")

在我的机器 AMD 5700x 上打印:

t_numba=1.2360494260210544
t_original=31.86005672905594

  • 使用
    201 x 201 x 24
    numba 功能采取了
    2.5117316420655698
  • 使用
    1024 x 1024 x 24
    numba 功能采取了
    62.206267355941236
© www.soinside.com 2019 - 2024. All rights reserved.