我发现这个函数是我程序的性能瓶颈,需要提高它的效率。
为了加快速度,我想使用 NumPy 的向量化运算,但一直没能实现。我知道向量化可以避免在一个 for 循环里再嵌套另一个 for 循环,但我不知道具体该怎么做,因为输出必须保持不变。目前我的代码如下:
def filtered_lines_calculation(lines, resolution=None):
    """Reduce Hough line segments that lie close together to one segment.

    Per the original description, the segments come from ``detect_grid``.
    They are visited in input order. A segment is dropped when some
    already-kept segment of the same orientation class (both |slope| < 1 or
    both |slope| > 1) has its start point closer than the threshold to the
    infinite line through the current segment; otherwise it is kept.

    Args:
        lines: Array-like of shape (N, 1, 4) holding ``[x1, y1, x2, y2]``
            per segment. ``None`` or an empty sequence means "no lines".
        resolution: 0, 1 or 2, selecting a distance threshold of 75, 50 or
            30 (in coordinate units). When omitted, the module-level
            ``RESOLUTION`` setting is used, as before.

    Returns:
        A list of 1-D arrays ``[x1, y1, x2, y2]``: the kept segments, in
        input order.

    Raises:
        ValueError: If the resolution is not 0, 1 or 2.
    """
    if resolution is None:
        resolution = RESOLUTION
    if resolution == 0:
        threshold = 75
    elif resolution == 1:
        threshold = 50
    elif resolution == 2:
        threshold = 30
    else:
        # Previously this fell through and failed later with an unbound
        # ``threshold``; fail early with a clear message instead.
        raise ValueError(f"RESOLUTION must be 0, 1 or 2, got {resolution!r}")

    if lines is None:
        return []
    lines = np.array(lines)
    if lines.size == 0:
        return []

    segments = lines[:, 0]
    coords = segments.astype(np.float64)
    x1, y1 = coords[:, 0], coords[:, 1]
    dx = coords[:, 2] - x1
    dy = coords[:, 3] - y1
    lengths = np.sqrt(dx * dx + dy * dy)

    # Slope per segment; where dx == 0 the division is skipped and the 1e6
    # sentinel is kept, so no divide-by-zero warning is raised.
    slopes = np.full(len(coords), 1e6)
    np.divide(dy, dx, out=slopes, where=dx != 0)
    flat = np.abs(slopes) < 1
    steep = np.abs(slopes) > 1

    # Indices of kept segments live in a preallocated buffer so that no
    # array has to be rebuilt on every iteration.
    kept = np.empty(len(coords), dtype=np.intp)
    n_kept = 0
    for i in range(len(coords)):
        too_close = False
        # A zero-length segment has no direction and a slope of exactly +-1
        # belongs to neither class; such segments are never compared against
        # the kept ones and are therefore always kept.
        if n_kept and lengths[i] > 0 and (flat[i] or steep[i]):
            others = kept[:n_kept]
            same_orientation = flat[others] if flat[i] else steep[others]
            # 2-D cross product of the segment direction with the vector
            # from each kept start point to this segment's start point.
            cross = dx[i] * (y1[i] - y1[others]) - dy[i] * (x1[i] - x1[others])
            distance = np.abs(cross) / lengths[i]
            too_close = bool(np.any(same_orientation & (distance < threshold)))
        if not too_close:
            kept[n_kept] = i
            n_kept += 1
    return [segments[i] for i in kept[:n_kept]]
其中 lines 是一个线段列表,如下所示:
[[[ 0 40 211 47]]
[[ 0 91 211 98]]
[[ 6 92 211 99]]
[[ 27 42 211 49]]
[[173 1 184 118]]
[[ 19 118 34 0]]
[[ 75 118 81 1]]
[[ 0 39 191 46]]
[[129 0 131 118]]
[[ 76 118 82 0]]]
这是一个稍作修改的 numba 版本:修改之处在于函数现在返回索引列表(您可以用这个列表对原始的 `lines` 进行索引,以获得所需的结果)。
另外,现在还必须把 `np.array` 类型的数组作为函数参数传入:
from timeit import timeit

import numpy as np
from numba import njit
from numba.np.extensions import cross2d
from numba.typed import List
def filtered_lines_calculation(lines, RESOLUTION):
    """Reduce Hough line segments that lie close together to one segment.

    Per the original description, the segments come from ``detect_grid``.
    They are visited in input order. A segment is dropped when some
    already-kept segment of the same orientation class (both |slope| < 1 or
    both |slope| > 1) has its start point closer than the threshold to the
    infinite line through the current segment; otherwise it is kept.

    Args:
        lines: Array-like of shape (N, 1, 4) holding ``[x1, y1, x2, y2]``
            per segment. ``None`` or an empty sequence means "no lines".
        RESOLUTION: 0, 1 or 2, selecting a distance threshold of 75, 50 or
            30 (in coordinate units) respectively.

    Returns:
        A list of 1-D arrays ``[x1, y1, x2, y2]``: the kept segments, in
        input order.

    Raises:
        ValueError: If ``RESOLUTION`` is not 0, 1 or 2.
    """
    if RESOLUTION == 0:
        threshold = 75
    elif RESOLUTION == 1:
        threshold = 50
    elif RESOLUTION == 2:
        threshold = 30
    else:
        # Previously this fell through and failed later with an unbound
        # ``threshold``; fail early with a clear message instead.
        raise ValueError(f"RESOLUTION must be 0, 1 or 2, got {RESOLUTION!r}")

    if lines is None:
        return []
    lines = np.asarray(lines)
    if lines.size == 0:
        return []

    segments = lines[:, 0]
    coords = segments.astype(np.float64)
    dx = coords[:, 2] - coords[:, 0]
    dy = coords[:, 3] - coords[:, 1]

    # Slope per segment; where dx == 0 the division is skipped and the 1e6
    # sentinel is kept, so no divide-by-zero warning is raised.
    slopes = np.full(len(coords), 1e6)
    np.divide(dy, dx, out=slopes, where=dx != 0)
    steepness = np.abs(slopes)

    # Everything that depends on a single segment is computed once up front
    # and converted to plain Python values for the scalar double loop below.
    x1 = coords[:, 0].tolist()
    y1 = coords[:, 1].tolist()
    run = dx.tolist()
    rise = dy.tolist()
    length = np.sqrt(dx * dx + dy * dy).tolist()
    is_flat = (steepness < 1).tolist()
    is_steep = (steepness > 1).tolist()

    kept = []
    for i in range(len(x1)):
        too_close = False
        # A zero-length segment has no direction, so it is never compared
        # against the kept segments and is always kept.
        if length[i] > 0:
            for j in kept:
                same_orientation = (is_flat[i] and is_flat[j]) or (
                    is_steep[i] and is_steep[j]
                )
                if not same_orientation:
                    continue
                # 2-D cross product of this segment's direction with the
                # vector from the kept start point to this start point.
                cross = run[i] * (y1[i] - y1[j]) - rise[i] * (x1[i] - x1[j])
                if abs(cross) / length[i] < threshold:
                    too_close = True
                    break
        if not too_close:
            kept.append(i)
    return [segments[i] for i in kept]
@njit
def numba_norm(a):
    """Return the Euclidean length of a 2-D vector.

    Only ``a[0]`` and ``a[1]`` are read.
    """
    x = a[0]
    y = a[1]
    return np.sqrt(x * x + y * y)
@njit
def filtered_lines_calculation_numba(lines, RESOLUTION):
    """Numba version of the Hough-line reduction that returns indices.

    Segments are visited in input order. A segment is dropped when some
    already-kept segment of the same orientation class (both |slope| < 1 or
    both |slope| > 1) has its start point closer than the threshold to the
    infinite line through the current segment; otherwise its index is kept.

    Args:
        lines: NumPy array of shape (N, 1, 4) holding ``[x1, y1, x2, y2]``
            per segment.
        RESOLUTION: 0, 1 or 2, selecting a distance threshold of 75, 50 or
            30 (in coordinate units) respectively.

    Returns:
        A numba typed list with the indices (into ``lines``) of the kept
        segments, in input order.

    Raises:
        ValueError: If ``RESOLUTION`` is not 0, 1 or 2.
    """
    if RESOLUTION == 0:
        threshold = 75
    elif RESOLUTION == 1:
        threshold = 50
    elif RESOLUTION == 2:
        threshold = 30
    else:
        # Without this branch ``threshold`` would be left undefined.
        raise ValueError("RESOLUTION must be 0, 1 or 2")

    filtered_lines = List.empty_list(np.int64)

    slopes = (lines[:, :, 3] - lines[:, :, 1]) / (lines[:, :, 2] - lines[:, :, 0])
    # slopes has shape (N, 1); a 1-D row mask is used to overwrite the
    # infinite entries (vertical lines) with the 1e6 sentinel slope.
    mask = np.isinf(slopes)[:, 0]
    slopes[mask] = 1e6

    for line_index, line in enumerate(lines):
        p1, p2 = line[0][:2], line[0][2:]
        too_close = False
        slope = slopes[line_index]

        for idx in filtered_lines:
            other_line = lines[idx][0]
            p3, p4 = other_line[:2], other_line[2:]

            # A vertical kept segment must get the same 1e6 sentinel slope
            # that filtered_lines_calculation uses. Substituting 1e6 for the
            # zero denominator instead would give a slope near 0 and
            # misclassify the vertical segment as horizontal.
            run = p4[0] - p3[0]
            if run == 0:
                other_slope = 1e6
            else:
                other_slope = (p4[1] - p3[1]) / run

            if (np.abs(slope) < 1 and np.abs(other_slope) < 1) or (
                np.abs(slope) > 1 and np.abs(other_slope) > 1
            ):
                # Perpendicular distance from p3 to the line through p1, p2.
                distance = np.abs(cross2d(p2 - p1, p1 - p3)) / numba_norm(p2 - p1)
                if distance < threshold:
                    too_close = True
                    break

        if not too_close:
            filtered_lines.append(line_index)

    return filtered_lines
RESOLUTION = 1
lines = np.array(
    [
        [[0, 40, 211, 47]],
        [[0, 91, 211, 98]],
        [[6, 92, 211, 99]],
        [[27, 42, 211, 49]],
        [[173, 1, 184, 118]],
        [[19, 118, 34, 0]],
        [[75, 118, 81, 1]],
        [[0, 39, 191, 46]],
        [[129, 0, 131, 118]],
        [[76, 118, 82, 0]],
    ]
)

# Check that both implementations agree. This first call also compiles the
# numba function, so compilation is not part of the timing below.
out1 = filtered_lines_calculation(lines, RESOLUTION)
out2 = filtered_lines_calculation_numba(lines, RESOLUTION)
# zip() stops at the shorter input, so the element-wise comparison alone
# would pass if one result were a prefix of the other; compare lengths too.
assert len(out1) == len(out2)
assert all(np.allclose(i, j[0]) for i, j in zip(out1, lines[out2]))

# Build a bigger input (the sample repeated 10000 times) for the timing run.
lines = np.concatenate([lines for _ in range(10000)])

t1 = timeit(
    "filtered_lines_calculation(lines, RESOLUTION)", number=1, globals=globals()
)
# The numba function returns indices, so index `lines` to get the segments.
t2 = timeit(
    "lines[filtered_lines_calculation_numba(lines, RESOLUTION)]",
    number=1,
    globals=globals(),
)

print(f"Time normal = {t1}")
print(f"Time numba = {t2}")
在我的机器(AMD 5700X)上,输出如下:
Time normal = 3.1929642129689455
Time numba = 0.032550607807934284