如何最有效地计算矩阵乘积的对角线[重复项]

问题描述 投票:0回答:1
我想计算以下内容:

import numpy as np n= 3 m = 2 x = np.random.randn(n,m) #Method 1 y = np.zeros(m) for i in range(m): y[i] = x[:,i] @ x[:,i] #Method 2 y2 = np.diag(x.T @ x)

第一种方法有一个问题,它使用for循环,效率不是很高(我需要在GPU上的py​​torch中这样做数百万次)

当我只需要对角线条目时,第二种方法将计算全矩阵乘积,因此也不是非常有效。

我想知道是否存在任何聪明的方法?

python numpy optimization pytorch linear-algebra
1个回答
1
投票
使用人工构造的和积。您需要各个列的平方和:

y = (x * x).sum(axis=0)

正如Divakar所建议的,np.einsum可能会提供较少的内存密集型选项,因为它不需要临时数组np.einsum

x * x

© www.soinside.com 2019 - 2024. All rights reserved.