4 维张量内沿给定轴的 Sub-2x2 矩阵乘积

问题描述 投票:0回答:1

我有一个 4 维的 ndarray 结构:(K,N,2,2)。您可以将其想象为 K 个不同的堆栈,每个堆栈包含 N 个维度为 2x2 的矩阵。 对于每个堆栈,我尝试计算其 N 个矩阵的矩阵乘积(没有 For 循环)。所以最后,我应该有 K 个 2x2 维度的矩阵(N 个 2x2 矩阵的乘积仍然是 2x2 矩阵)。

如果我的数组的维度为 (N,2,2),则使用以下函数沿第一个轴执行矩阵乘积非常简单:

A_total = np.linalg.multi_dot(A)  # A being a (N,2,2) array

但是对于 (K,N,2,2) 结构,我无法在忽略第一个轴的情况下沿第二个轴执行相同的操作。

您对此有何建议?我尝试过 np.einsum() 和 np.tensordot() 但不太明白如何正确使用这些函数。

python numpy matrix-multiplication tensor
1个回答
0
投票

让我们尝试一下小 K=3、N=5 的 2 个明显替代方案

In [95]: A = np.random.rand(3,5,2,2)

首先迭代 K,堆栈:

In [96]: res = np.array([np.linalg.multi_dot(A[i,...]) for i in range(A.shape[0])])

In [97]: res.shape
Out[97]: (3, 2, 2)

现在链上N:

In [98]: res1=A[:,0,:,:].copy()
    ...: for j in range(1,A.shape[1]):
    ...:     res1 = res1@A[:,j,:,:]
    ...:     

他们匹配:

In [99]: np.allclose(res, res1)
Out[99]: True

迭代链的时机要好得多:

In [100]: timeit res = np.array([np.linalg.multi_dot(A[i,...]) for i in range(A.shape[0])])
329 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [101]: %%timeit 
     ...: res1=A[:,0,:,:].copy()
     ...: for j in range(1,A.shape[1]):
     ...:     res1 = res1@A[:,j,:,:]
     ...:     
33.7 µs ± 767 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

因此,除非 K/N 比率非常不同,否则我会坚持重复的

matmul
matmul
在(3+)的引导尺寸上是线性的,但该迭代是在编译代码中。

© www.soinside.com 2019 - 2024. All rights reserved.