我正在处理一个简单的问题与numpy。我有两个矩阵列表 - 比如A,B
- 分别编码为具有形状(n,p,q)
和(n,q,r)
的3D数组。
我想计算他们的元素点积,这是一个三维数组C
,使C[i,j,l] = sum A[i,j,:] B[i,:,l]
。从数学上讲,这非常简单,但这是我必须遵循的规则:
1)我必须只使用numpy函数(dot
,tensordot
,einsum
等):no loop&cie。这是因为我希望这可以在我的gpu上工作(带有杯状),并且循环很糟糕。我希望在当前设备上进行所有操作。
2)由于我的数据可能非常大,通常A
和B
已经在内存中占用了几十个Mb,我不想构建任何形状比(n,p,q),(n,q,r),(n,p,r)
更大的项目(不必存储中间4D数组)。
例如,我找到there的解决方案,即使用:
C = np.sum(np.transpose(A,(0,2,1)).reshape(n,p,q,1)*B.reshape(n,q,1,r),-3)
从数学的角度讲是正确的,但暗示了(n,p,q,r)数组的中间创建,这对我来说太大了。
我遇到类似的问题
C = np.einsum('ipq,iqr->ipr',A,B)
我不知道什么是底层操作和构造,但它总是导致内存错误。
另一方面,有点像天真的东西:
C = np.array([A[i].dot(B[i]) for i in range(n)])
在内存方面似乎没问题,但在我的gpu上效率不高:列表是在CPU上构建的,并且将它重新分配给gpu很慢(如果有一个友好的方式来编写它,它将是一个好的解决方案!)
谢谢您的帮助 !
你想要numpy.matmul
(cupy version here)。 matmul
是一个“广播”矩阵倍增。
我认为人们已经知道numpy.dot
语义是不稳定的,并且需要广播矩阵乘法,但是在python获得@
运算符之前引入变化的动力并不大。我没有看到dot
去任何地方,但我怀疑更好的语义和做A @ B
的容易性将意味着当人们发现新的功能和运营商时,dot
将失宠。
您寻求避免的迭代方法可能不会那么糟糕。例如,考虑这些时间:
In [51]: A = np.ones((100,10,10))
In [52]: timeit np.array([A[i].dot(A[i]) for i in range(A.shape[0])])
439 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [53]: timeit np.einsum('ipq,iqr->ipr',A,A)
428 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [54]: timeit A@A
426 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
对于这种情况,所有三个大约需要同一时间。
但是我将后续维度加倍,迭代方法实际上更快:
In [55]: A = np.ones((100,20,20))
In [56]: timeit np.array([A[i].dot(A[i]) for i in range(A.shape[0])])
702 µs ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [57]: timeit np.einsum('ipq,iqr->ipr',A,A)
1.89 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [58]: timeit A@A
1.89 ms ± 490 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
当我改变20到30和40时,同样的模式成立。我有点惊讶matmul
时代与einsum
如此接近。
我想我可以尝试将这些推到内存限制。我没有一个花哨的后端来测试这个方面。
一旦考虑到内存管理问题,对大问题的少量迭代就不那么可怕了。在numpy中,你想要避免的事情是在一个简单的任务上进行多次迭代。