我有一个三维网格。对于网格上的每个点,我想找到它的相邻点。由于我的网格是均匀采样的,我只需要收集与该点直接相邻的点(即各方向上索引相差不超过 1 的点)。
示例网格:
所需邻居:
对于给定点 X,我需要以下邻居:
我目前的工作代码:
import numpy as np
import cProfile
class Neighbours:
    """Neighbour lookup for the points of a regular 3-D (row, col, frame) grid."""

    @staticmethod
    def _grid_shape(shape):
        """Return ``(n_rows, n_cols, n_frames)`` for the grid.

        When ``shape`` is None the module-level names ``nRows``, ``nCols`` and
        ``nFrames`` are used, which keeps existing call sites (that pass no
        shape) working.
        """
        if shape is None:
            return nRows, nCols, nFrames
        n_rows, n_cols, n_frames = shape
        return n_rows, n_cols, n_frames

    @classmethod
    def get_neighbour_indices(cls, row, col, frame, distance=1, shape=None):
        """Return the in-bounds neighbourhood of the point (row, col, frame).

        Args:
            row, col, frame: Integer coordinates of the centre point.
            distance: Half-width of the cubic window around the point.
            shape: Optional ``(n_rows, n_cols, n_frames)``; defaults to the
                module-level ``nRows``/``nCols``/``nFrames``.

        Returns:
            An ``(n, 3)`` integer array whose columns are row, col and frame.
            Rows are ordered with the row index varying slowest and the frame
            index fastest. The centre point itself is included. The array is
            empty (shape ``(0, 3)``) when the window lies outside the grid.
        """
        n_rows, n_cols, n_frames = cls._grid_shape(shape)

        # Clamp each axis to the grid *before* building the window, so no
        # out-of-range candidates have to be generated and filtered away.
        # np.arange yields integers, so the result can index arrays directly.
        rows = np.arange(max(row - distance, 0), min(row + distance, n_rows - 1) + 1)
        cols = np.arange(max(col - distance, 0), min(col + distance, n_cols - 1) + 1)
        frames = np.arange(
            max(frame - distance, 0), min(frame + distance, n_frames - 1) + 1
        )

        # "ij" indexing keeps the (row, col, frame) axis order when flattened.
        nr, nc, nf = np.meshgrid(rows, cols, frames, indexing="ij")
        return np.stack((nr.ravel(), nc.ravel(), nf.ravel()), axis=1)

    @classmethod
    def MapIndexVsNeighbours(cls, shape=None):
        """Build the neighbour table for every point of the grid.

        Args:
            shape: Optional ``(n_rows, n_cols, n_frames)``; defaults to the
                module-level ``nRows``/``nCols``/``nFrames``.

        Returns:
            A 1-D object array of length ``n_rows * n_cols * n_frames``.
            Element ``frame * (n_rows * n_cols) + row * n_cols + col`` holds
            the array returned by ``get_neighbour_indices`` for that point
            with ``distance=1``.
        """
        n_rows, n_cols, n_frames = cls._grid_shape(shape)
        grid = (n_rows, n_cols, n_frames)
        plane = n_rows * n_cols  # number of points in one frame

        neighbours_info = np.empty(plane * n_frames, dtype=object)
        for frame in range(n_frames):
            for row in range(n_rows):
                for col in range(n_cols):
                    flat_idx = frame * plane + (row * n_cols + col)
                    neighbours_info[flat_idx] = cls.get_neighbour_indices(
                        row, col, frame, distance=1, shape=grid
                    )
        return neighbours_info
# ---------------------------------------------------------------------------
# Script entry point: profile the construction of the neighbour table.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Grid dimensions. They are bound as module-level names because the
    # Neighbours class looks them up by these exact names.
    nRows, nCols, nFrames = 151, 151, 24

    # Profile one full build and list the results by cumulative time.
    cProfile.run("Neighbours.MapIndexVsNeighbours()", sort="cumulative")
    print()
问题:对于较大的网格(例如 201 x 201 x 24),程序需要很长时间。在 cProfile 的分析结果中,我可以看到 get_neighbour_indices() 中的 meshgrid() 花费了相当长的时间。总而言之,这不是一个高效的实现。此外,我尝试在单独的线程上执行 MapIndexVsNeighbours(),但由于 GIL 的限制,它并没有真正并行执行。因此,能够并行执行的方案才是理想的实现。
您可以使用 numba 来加速计算:
from timeit import timeit
from numba import njit
@njit
def get_arr(x, y, f, w, h, distance, frames):
    """Return the in-bounds neighbourhood of the grid point (x, y, f).

    Args:
        x, y, f: Integer coordinates of the centre point.
        w: Extent of the first axis; ``x`` values are clamped to ``[0, w - 1]``.
        h: Extent of the second axis; ``y`` values are clamped to ``[0, h - 1]``.
        distance: Half-width of the cubic window around the point.
        frames: Extent of the third axis; values are clamped to
            ``[0, frames - 1]``.

    Returns:
        An ``(n, 3)`` int64 array of ``[x, y, z]`` triples (centre included),
        ordered with ``x`` varying slowest and ``z`` fastest. The array has
        shape ``(0, 3)`` when the window does not intersect the grid.
    """
    x_lo = max(0, x - distance)
    x_hi = min(w - 1, x + distance)
    y_lo = max(0, y - distance)
    y_hi = min(h - 1, y + distance)
    z_lo = max(0, f - distance)
    z_hi = min(frames - 1, f + distance)

    # Number of in-bounds values per axis; zero when the range is empty.
    n_x = max(0, x_hi - x_lo + 1)
    n_y = max(0, y_hi - y_lo + 1)
    n_z = max(0, z_hi - z_lo + 1)

    # Preallocate the result instead of growing a list of lists and
    # converting it afterwards: one allocation, no intermediate containers.
    out = np.empty((n_x * n_y * n_z, 3), dtype=np.int64)
    k = 0
    for _x in range(x_lo, x_hi + 1):
        for _y in range(y_lo, y_hi + 1):
            for _z in range(z_lo, z_hi + 1):
                out[k, 0] = _x
                out[k, 1] = _y
                out[k, 2] = _z
                k += 1
    return out
def MapIndexVsNeighbours_numba(nRows, nCols, nFrames):
    """Build the per-point neighbour table using ``get_arr``.

    Returns a 1-D object array with ``nRows * nCols * nFrames`` entries. The
    entry for the point (row, col, frame) is stored at position
    ``frame * (nRows * nCols) + row * nCols + col`` and holds the result of
    ``get_arr`` for that point with a distance of 1.
    """
    table = np.empty(nRows * nCols * nFrames, dtype=object)

    # The loops visit the points frame by frame, row by row, column by
    # column, which is exactly the order of the flat index, so a running
    # counter gives the same position as the explicit formula.
    position = 0
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                table[position] = get_arr(row, col, frame, nRows, nCols, 1, nFrames)
                position += 1
    return table
# Grid dimensions used by both implementations. Neighbours looks these up as
# module-level names, so they must keep exactly these names.
nRows = 151
nCols = 151
nFrames = 24
# Build the neighbour table once with each implementation. This first call to
# the numba version happens before the timed run below -- presumably so that
# numba's first-call compilation is not included in the measurement.
v1 = MapIndexVsNeighbours_numba(nRows, nCols, nFrames)
v2 = Neighbours.MapIndexVsNeighbours()
# Sanity check: both implementations must produce numerically equal
# neighbour arrays for every grid point, in the same order.
assert all(np.allclose(a, b) for a, b in zip(v1, v2))
# Time a single run of each implementation. The statements are given as
# strings and resolved against this module's globals.
t_numba = timeit(
"MapIndexVsNeighbours_numba(nRows, nCols, nFrames)", number=1, globals=globals()
)
t_original = timeit("Neighbours.MapIndexVsNeighbours()", number=1, globals=globals())
# The "=" specifier prints the variable name together with its value.
print(f"{t_numba=}")
print(f"{t_original=}")
在我的机器 AMD 5700x 上打印:
t_numba=1.2360494260210544
t_original=31.86005672905594
对于 201 x 201 x 24 的网格:
numba 函数耗时 2.5117316420655698 秒
对于 1024 x 1024 x 24 的网格:
numba 函数耗时 62.206267355941236 秒