Stack Overflow 社区您好,
我正在使用 NumPy 进行矩阵运算,我有一个关于 NumPy 如何处理矩阵乘法的问题,特别是在处理不连续的矩阵切片时。
考虑一个场景,我们有一个大矩阵,例如大小为 [1000, 1000],并且我们需要使用步骤对该矩阵的切片版本执行矩阵乘法,例如 [::10, ::10]。据我了解,NumPy 可能使用优化的 BLAS 例程(如
GEMM
)进行底层矩阵乘法。然而,BLAS 例程通常需要连续的内存布局才能有效运行。
我的问题是:NumPy 内部如何处理由于步骤切片而导致乘法输入矩阵不连续的情况?具体来说,我有兴趣了解 NumPy:
GEMM
。这些信息将帮助我更好地理解在 NumPy 中使用切片和矩阵乘法步骤对性能的影响。
预先感谢您的见解!
np.matmul
做了相当多的工作试图弄清楚何时可以将工作传递给 BLAS。实现它的主要源文件是numpy/_core/src/umath/matmul.c.src
,具体看一下@TYPE@_matmul()
和is_blasable2d()
.
具体来说,
is_blasable2d
上的评论检查:
- 步幅不得混叠或重叠
- 较快(第二)轴必须是连续的
- 较慢的(第一)轴步幅(以单位步长计)必须大于 快轴尺寸
因此,由于第二个约束,您的示例应使用较慢的
_noblas
变体,即第二个轴不连续。
作为健全性检查,我们查看运行时间是否与上述观察结果一致:
import numpy as np
arr = np.zeros((1000, 1000))
%timeit arr[::2,::2] @ arr[::2,::2] # takes ~300ms
%timeit arr[::2,:500] @ arr[::2,:500] # takes ~ 4ms
%timeit arr[:500,:500] @ arr[:500,:500] # takes ~ 4ms
# as pointed out by hpaulj, the following takes ~ 5ms
%timeit np.ascontiguousarray(arr[::2,::2]) @ np.ascontiguousarray(arr[::2,::2])
这似乎是正确的,第一个变体有一个不连续的第二轴,并且速度慢得多,大概是因为它没有使用 BLAS。其他变体可能更快,因为它们被传递给 BLAS。制作连续的副本需要一些时间,但生成的运行时间更快,因此在必要时这样做看起来是值得的。