NumPy 内部如何处理非连续切片的矩阵乘法?

问题描述 投票:0回答:1

Stack Overflow 社区您好,

我正在使用 NumPy 进行矩阵运算,我有一个关于 NumPy 如何处理矩阵乘法的问题,特别是在处理不连续的矩阵切片时。

考虑一个场景,我们有一个大矩阵,例如大小为 [1000, 1000],并且我们需要使用步骤对该矩阵的切片版本执行矩阵乘法,例如 [::10, ::10]。据我了解,NumPy 可能使用优化的 BLAS 例程(如

GEMM
)进行底层矩阵乘法。然而,BLAS 例程通常需要连续的内存布局才能有效运行。

我的问题是:NumPy 内部如何处理由于步骤切片而导致乘法输入矩阵不连续的情况?具体来说,我有兴趣了解 NumPy:

  1. 自动将这些切片重新分配到新的连续内存块,然后执行
    GEMM
  2. 具有处理非连续切片的优化方法,无需重新分配内存。
  3. 使用 BLAS 例程的任何特定变体或 NumPy 自己的实现来处理此类情况。

这些信息将帮助我更好地理解在 NumPy 中使用切片和矩阵乘法步骤对性能的影响。

预先感谢您的见解!

python numpy blas
1个回答
0
投票

np.matmul
做了相当多的工作试图弄清楚何时可以将工作传递给 BLAS。实现它的主要源文件是
numpy/_core/src/umath/matmul.c.src
,具体看一下
@TYPE@_matmul()
is_blasable2d()
.

具体来说,

is_blasable2d
上的评论检查:

  1. 步幅不得混叠或重叠
  2. 较快(第二)轴必须是连续的
  3. 较慢的(第一)轴步幅(以单位步长计)必须大于 快轴尺寸

因此,由于第二个约束,您的示例应使用较慢的

_noblas
变体,即第二个轴不连续。

作为健全性检查,我们查看运行时间是否与上述观察结果一致:

import numpy as np

arr = np.zeros((1000, 1000))

%timeit arr[::2,::2] @ arr[::2,::2]     # takes ~300ms
%timeit arr[::2,:500] @ arr[::2,:500]   # takes ~  4ms
%timeit arr[:500,:500] @ arr[:500,:500] # takes ~  4ms

# as pointed out by hpaulj, the following takes ~  5ms
%timeit np.ascontiguousarray(arr[::2,::2]) @ np.ascontiguousarray(arr[::2,::2])

这似乎是正确的,第一个变体有一个不连续的第二轴,并且速度慢得多,大概是因为它没有使用 BLAS。其他变体可能更快,因为它们被传递给 BLAS。制作连续的副本需要一些时间,但生成的运行时间更快,因此在必要时这样做看起来是值得的。

© www.soinside.com 2019 - 2024. All rights reserved.