我正在使用 numba 在 python 中编写一个函数来标记 2D 或 3D 数组中的对象,这意味着输入数组中具有相同值的所有正交连接的单元将在输出数组中被赋予从 1 到 N 的唯一标签,其中 N是正交连接组的数量。它与
scipy.ndimage.label
等函数以及 scikit-image 等库中的类似函数非常相似,但这些函数标记所有正交连接的非零单元格组,因此它将合并具有不同值的连接组,我不想要。例如,给定以下输入:
[0 0 7 7 0 0
0 0 7 0 0 0
0 0 0 0 0 7
0 6 6 0 0 7
0 0 4 4 0 0]
scipy 函数将返回
[0 0 1 1 0 0
0 0 1 0 0 0
0 0 0 0 0 3
0 2 2 0 0 3
0 0 2 2 0 0]
请注意,6 和 4 已合并到标签
2
中。我希望将它们标记为单独的组,例如:
[0 0 1 1 0 0
0 0 1 0 0 0
0 0 0 0 0 4
0 2 2 0 0 4
0 0 3 3 0 0]
我大约一年前问过这个问题并且一直在使用已接受的答案中的解决方案,但是我正在努力优化代码的运行时并重新审视这个问题。
对于我通常使用的数据大小,链接的解决方案需要大约 1 分 30 秒才能运行。我编写了以下递归算法,该算法作为常规 python 运行大约需要 30 秒,而 numba 的 JIT 在 1-2 秒内运行(旁注,我讨厌那个相邻的函数,任何让它不那么混乱同时仍然与 numba 兼容的提示将不胜感激):
@numba.njit
def adjacent(idx, shape):
coords = []
if len(shape) > 2:
if idx[0] < shape[0] - 1:
coords.append((idx[0] + 1, idx[1], idx[2]))
if idx[0] > 0:
coords.append((idx[0] - 1, idx[1], idx[2]))
if idx[1] < shape[1] - 1:
coords.append((idx[0], idx[1] + 1, idx[2]))
if idx[1] > 0:
coords.append((idx[0], idx[1] - 1, idx[2]))
if idx[2] < shape[2] - 1:
coords.append((idx[0], idx[1], idx[2] + 1))
if idx[2] > 0:
coords.append((idx[0], idx[1], idx[2] - 1))
else:
if idx[0] < shape[0] - 1:
coords.append((idx[0] + 1, idx[1]))
if idx[0] > 0:
coords.append((idx[0] - 1, idx[1]))
if idx[1] < shape[1] - 1:
coords.append((idx[0], idx[1] + 1))
if idx[1] > 0:
coords.append((idx[0], idx[1] - 1))
return coords
@numba.njit
def apply_label(labels, decoded_image, current_label, idx):
labels[idx] = current_label
for aidx in adjacent(idx, labels.shape):
if decoded_image[aidx] == decoded_image[idx] and labels[aidx] == 0:
apply_label(labels, decoded_image, current_label, aidx)
@numba.njit
def label_image(decoded_image):
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image >= 0)):
if labels[idx] == 0:
current_label += 1
apply_label(labels, decoded_image, current_label, idx)
return labels, current_label
这适用于某些数据,但在其他数据上崩溃,我发现问题是当有非常大的对象要标记时,就会达到递归限制。我尝试重写
label_image
以不使用递归,但现在使用 numba 需要大约 10 秒。与我开始的地方相比仍然有很大的改进,但似乎应该可以获得与递归版本相同的性能。这是我的迭代版本:
@numba.njit
def label_image(decoded_image):
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image >= 0)):
if labels[idx] == 0:
current_label += 1
idxs = [idx]
while idxs:
cidx = idxs.pop()
if labels[cidx] == 0:
labels[cidx] = current_label
for aidx in adjacent(cidx, labels.shape):
if labels[aidx] == 0 and decoded_image[aidx] == decoded_image[idx]:
idxs.append(aidx)
return labels, current_label
有什么办法可以改善这个吗?
这个递归函数能否变成性能类似的迭代函数?
将其转换为迭代函数很简单,考虑到它只是一个简单的深度优先搜索(您也可以使用此处使用队列而不是堆栈的广度优先搜索,两者都可以)。只需使用堆栈来跟踪要访问的节点即可。这是适用于任意数量维度的通用解决方案:
def label_image(decoded_image):
shape = decoded_image.shape
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image > 0)):
if labels[idx] == 0:
current_label += 1
stack = [idx]
while stack:
top = stack.pop()
labels[top] = current_label
for i in range(0, len(shape)):
if top[i] > 0:
neighbor = list(top)
neighbor[i] -= 1
neighbor = tuple(neighbor)
if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
stack.append(neighbor)
if top[i] < shape[i] - 1:
neighbor = list(top)
neighbor[i] += 1
neighbor = tuple(neighbor)
if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
stack.append(neighbor)
return labels
从元组的第 i 个分量中添加或减去一个是很尴尬的(我将在此处查看临时列表)并且 numba 不接受它(类型错误)。一种简单的解决方案是显式编写 2d 和 3d 版本,这可能会极大地提高性能:
@numba.njit
def label_image_2d(decoded_image):
w, h = decoded_image.shape
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image > 0)):
if labels[idx] == 0:
current_label += 1
stack = [idx]
while stack:
x, y = stack.pop()
if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
continue # already visited or not part of this group
labels[x, y] = current_label
if x > 0: stack.append((x-1, y))
if x+1 < w: stack.append((x+1, y))
if y > 0: stack.append((x, y-1))
if y+1 < h: stack.append((x, y+1))
return labels
@numba.njit
def label_image_3d(decoded_image):
w, h, l = decoded_image.shape
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image > 0)):
if labels[idx] == 0:
current_label += 1
stack = [idx]
while stack:
x, y, z = stack.pop()
if decoded_image[x, y, z] != decoded_image[idx] or labels[x, y, z] != 0:
continue # already visited or not part of this group
labels[x, y, z] = current_label
if x > 0: stack.append((x-1, y, z))
if x+1 < w: stack.append((x+1, y, z))
if y > 0: stack.append((x, y-1, z))
if y+1 < h: stack.append((x, y+1, z))
if z > 0: stack.append((x, y, z-1))
if z+1 < l: stack.append((x, y, z+1))
return labels
def label_image(decoded_image):
dim = len(decoded_image.shape)
if dim == 2:
return label_image_2d(decoded_image)
assert dim == 3
return label_image_3d(decoded_image)
另请注意,迭代解决方案不受堆栈限制的影响:
np.full((100,100,100), 1)
在迭代解决方案中工作得很好,但在递归解决方案中失败(如果使用numba,则会出现段错误)。
做一个非常基本的基准测试
for i in range(1, 10000):
label_image(np.full((20,20,20), i))
(多次迭代以尽量减少 JIT 的影响,也可以进行一些预热运行,然后开始测量时间或类似的)
迭代解决方案似乎快了几倍(在我的机器上大约是 5 倍)。您可能可以优化递归解决方案并使其达到可比较的速度,例如通过避免临时
coords
列表或将 np.where
更改为 > 0
。
我不知道 numba 能够如何优化压缩的
np.where
。为了进一步优化,您可以考虑(和基准测试)使用显式嵌套 for x in range(0, w): for y in range(0, h):
循环。