Numpy中的矩阵乘法耗时过长

问题描述 投票:0回答:1

我想用numpy在Python中简单地实现一个损失函数(MSE),这是我的代码。

import numpy as np

def loss(X, y, w):
    N = (X.shape)[0]
    X_new = np.concatenate((np.ones((N, 1)), X), axis=1)
    E = y-np.matmul(X_new, w)
    E_t = np.transpose(E)
    loss_value = (1/N)*(np.matmul(E_t, E))
    return loss_value

我的代码是这样的: E 是(15000,1),并且 E_t 显然是(1,15000)。然而,在调试时,我发现 np.matmul(E_t,E) 需要太多时间。我有一台16GB内存和酷睿i7的笔记本电脑,所以对我来说,这很奇怪。np.matmul 在这里失败了。如果我处理的矩阵有这些维度,这是否正常?

numpy
1个回答
0
投票

在一个相当基本的4GB机器上。

In [477]: E=np.ones((15000, 1))                                                                        
In [478]: E.T@E                                                                                        
Out[478]: array([[15000.]])
In [479]: timeit E.T@E                                                                                 
10.5 µs ± 241 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

你没有告诉我们任何关于 X但假设最坏的情况。

In [480]: E=np.ones((15000, 1),object)                                                                 
In [481]: E.T@E                                                                                        
Out[481]: array([[15000]], dtype=object)
In [482]: timeit E.T@E                                                                                 
577 µs ± 492 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.