我试图使用numba加快我的代码的速度,但似乎不起作用。该程序与@jit
,@njit
或纯python花费相同的时间(约10秒)。但是我使用了numpy而不是list或dict。
这里是我的代码:
import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)
@njit
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
M = np.full((n+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
@profile
def main():
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
a = knapSack(W, wt, val, n)
main()
实际上,如果不更改方法本身,可能无法真正提高当前算法的性能。
您的N
数组包含大约10亿个对象(1001 * 1001 * 1001
)。您需要设置每个元素,因此您至少要进行十亿次操作。为了获得一个下限,我们假设设置一个数组元素需要一纳秒(实际上会花费更多的时间)。 10亿次操作,每个操作需要1纳秒,这意味着需要1秒才能完成。正如我说的那样,每次操作可能会花费超过1纳秒的时间,因此我们假设它花费10纳秒(可能有点高,但比1纳秒更现实),这意味着该算法总共需要10秒钟。
因此,您的输入的预期运行时间将在1秒到10秒之间。因此,如果您的Python版本需要10秒钟,则可能已经是您选择的方法所能达到的极限,并且没有任何工具能够(显着)改善运行时间。
可能使速度更快的一件事是使用np.zeros
而不是np.full
:
K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)
并且不要创建M
,因为您将不使用它。
由于您已经使用过line-profiler,所以我决定看一下,得到了这个结果:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 19137.0 19137.0 0.0 K = np.full((n+1,W+1),0)
5 1 19408592.0 19408592.0 28.1 N = np.full((n+1,W+1,W+1),0)
6
7 1002 6412.0 6.4 0.0 for i in range(n+1):
8 1003002 4186311.0 4.2 6.1 for w in range(W+1):
9 1002001 4644031.0 4.6 6.7 if i==0 or w==0:
10 2001 19663.0 9.8 0.0 K[i][w] = 0
11 1000000 5474080.0 5.5 7.9 elif wt[i-1] <= w:
12 498365 9616406.0 19.3 13.9 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 52596 902030.0 17.2 1.3 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 52596 578740.0 11.0 0.8 c = N[i-1][w-wt[i-1]]
15 52596 295980.0 5.6 0.4 c[i] = i
16 52596 1239792.0 23.6 1.8 N[i][w] = c
17 else:
18 445769 5100917.0 11.4 7.4 K[i][w] = K[i-1][w]
19 445769 11677683.0 26.2 16.9 N[i][w] = N[i-1][w]
20 else:
21 501635 5801328.0 11.6 8.4 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 14.0 14.0 0.0 return N[n][W]
这表明瓶颈是np.full
,N[i][w] = N[i-1][w]
和if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
。 Numba不会改进前两个代码,因为它们已经使用了高度优化的NumPy代码,而numba可能会更慢一些。 Numba可能可以改善if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
,但可能不会引起注意。
如果np.full
被np.zeros
替换,则配置文件会稍有变化:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 747.0 747.0 0.0 K = np.zeros((n+1, W+1),dtype=int)
5 1 109592.0 109592.0 0.2 N = np.zeros((n+1, W+1, W+1),dtype=int)
6
7 1002 4230.0 4.2 0.0 for i in range(n+1):
8 1003002 4414071.0 4.4 7.0 for w in range(W+1):
9 1002001 4836807.0 4.8 7.7 if i==0 or w==0:
10 2001 22282.0 11.1 0.0 K[i][w] = 0
11 1000000 5646859.0 5.6 8.9 elif wt[i-1] <= w:
12 521222 10389581.0 19.9 16.5 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 47579 784563.0 16.5 1.2 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 47579 509056.0 10.7 0.8 c = N[i-1][w-wt[i-1]]
15 47579 362796.0 7.6 0.6 c[i] = i
16 47579 1975916.0 41.5 3.1 N[i][w] = c
17 else:
18 473643 5579823.0 11.8 8.8 K[i][w] = K[i-1][w]
19 473643 22805846.0 48.1 36.1 N[i][w] = N[i-1][w]
20 else:
21 478778 5664271.0 11.8 9.0 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 10.0 10.0 0.0 return N[n][W]
但是主要瓶颈仍然是N[i][w] = N[i-1][w]
,使用numba可能比使用纯NumPy慢。因此,再次使用numba对代码的其他一些部分所做的改进可能不会引起注意(再次)。
对于第一个配置文件,我使用了此版本的代码(第二个配置文件只是将np.full
更改为np.zeros
):
import numpy as np
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
%lprun -f knapSack knapSack(W, wt, val, n)
这里是新功能:
@njit
def knapSack(W, wt, val, n):
K = np.zeros((n + 1, W + 1),dtype=np.int32)
# In fact we must only save the previous combinations and the current,
# not all :) So N is considerably reduce
N = np.zeros((2, W + 1, W + 1),dtype=np.int32)
for i in range(n + 1):
for w in range(W + 1):
if i == 0 or w == 0:
K[i][w] = 0
elif wt[i - 1] <= w:
if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
N[i%2][w][i] = i
else:
K[i][w] = K[i - 1][w]
N[i%2][w] = N[(i - 1)%2][w]
else:
K[i][w] = K[i - 1][w]
N[(n)%2][W][0] = K[n][W]
return N[(n)%2][W]
非常感谢您 MSeifert !!