我有下面的一段代码,用于计算一组批处理特征上的马哈拉诺比斯距离,在我的设备上大约需要 100 毫秒,其中大部分是由于 delta 和 inv_covariance 之间的矩阵乘法造成的
delta 是维度为 874x32x100 的矩阵,inv_covariance 维度为 874x100x100
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
batch, channel, height, width = embedding.shape
embedding = embedding.reshape(batch, channel, height * width)
# calculate mahalanobis distances
delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
distances = distances.reshape(batch, 1, height, width)
distances = np.sqrt(distances.clip(0))
return distances
我尝试将代码转换为使用 numba 和 @njit,我已经预先分配了中间矩阵,并且我尝试使用 for 循环执行较小的矩阵乘法,因为 3 维矩阵不支持 matmul。
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
batch, channel, height, width = embedding.shape
embedding = embedding.reshape(batch, channel, height * width)
# calculate mahalanobis distances
delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
inv_covariance = np.ascontiguousarray(inv_covariance)
intermediate_matrix = np.zeros_like(delta)
for i in range(intermediate_matrix.shape[0]):
intermediate_matrix[i] = delta[i] @ inv_covariance[i]
distances = (intermediate_matrix * delta).sum(2).transpose(1, 0)
distances = np.ascontiguousarray(distances)
distances = distances.reshape(batch, 1, height, width)
distances = np.sqrt(distances.clip(0))
return distances
我添加了一些ascontigiousarray,最后一个很重要或者代码不起作用,其他的用来抑制警告说@会执行得更快(看起来并不算太多)。
有没有办法让代码更快,要么通过改进它,要么以不同的数学方式重新思考?
首先,矩阵乘法是由名为 BLAS 的库完成的,并且大多数实现都是高效的并行实现。话虽这么说,对于批量小矩阵,并行实现不可能那么高效。事实上,粒度太小,因此使用多线程的开销变得很大。最好并行化外循环并使用顺序矩阵乘法代码。 由于矩阵乘法涉及的矩阵非常小,因此最好
手动重新实现矩阵乘法。事实上,这消除了调用矩阵乘法库(BLAS)函数的开销,并且还确保在矩阵乘法期间不使用线程。不过,我们需要关心连续读取/写入值,因此该操作是SIMD 友好。 最重要的是,矩阵乘法可以与下一行合并
(intermediate_matrix * delta).sum(2)
,以便写入较小的输出数组并避免读回大的临时数组。这很重要,因为 RAM 速度很慢
。此策略还减少了内存占用,同时速度更快且可扩展性更好。尽管我没有测试它,但将操作与行
(embedding - mean).transpose(2, 0, 1)
合并当然是个好主意。
实施
@nb.njit()
def matmul(delta, inv_covariance):
si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
assert sk == inv_covariance.shape[0]
line = np.zeros(sj, dtype=delta.dtype)
res = np.zeros(si, dtype=delta.dtype)
for i in range(si):
line.fill(0.0)
for k in range(sk):
factor = delta[i, k]
for j in range(sj):
line[j] += factor * inv_covariance[k, j]
for j in range(sj):
res[i] += line[j] * delta[i, j]
return res
@nb.njit(parallel=True)
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
batch, channel, height, width = embedding.shape
embedding = embedding.reshape(batch, channel, height * width)
# calculate mahalanobis distances
delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
inv_covariance = np.ascontiguousarray(inv_covariance)
intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]))
for i in nb.prange(intermediate_matrix.shape[0]):
intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])
distances = intermediate_matrix.transpose(1, 0)
distances = np.ascontiguousarray(distances)
distances = distances.reshape(batch, 1, height, width)
distances = np.sqrt(distances.clip(0))
return distances
结果
np.allclose
)。