为什么numba不能提高背包功能的速度?

问题描述 投票:1回答:2

我试图使用numba加快我的代码的速度,但似乎不起作用。该程序与@jit@njit或纯python花费相同的时间(约10秒)。但是我使用了numpy而不是list或dict。

这里是我的代码:

import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)

@njit
def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N =  np.full((n+1,W+1,W+1),0)
    M =  np.full((n+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

@profile
def main():

    size = 1000
    val = [random.randint(1, size) for i in range(0, size)]
    wt = [random.randint(1, size) for i in range(0, size)]
    W = 1000
    n = len(val)
    a = knapSack(W, wt, val, n)
main()

python python-3.x performance jit numba
2个回答
1
投票

实际上,如果不更改方法本身,可能无法真正提高当前算法的性能。

您的N数组包含大约10亿个对象(1001 * 1001 * 1001)。您需要设置每个元素,因此您至少要进行十亿次操作。为了获得一个下限,我们假设设置一个数组元素需要一纳秒(实际上会花费更多的时间)。 10亿次操作,每个操作需要1纳秒,这意味着需要1秒才能完成。正如我说的那样,每次操作可能会花费超过1纳秒的时间,因此我们假设它花费10纳秒(可能有点高,但比1纳秒更现实),这意味着该算法总共需要10秒钟。

因此,您的输入的预期运行时间将在1秒到10秒之间。因此,如果您的Python版本需要10秒钟,则可能已经是您选择的方法所能达到的极限,并且没有任何工具能够(显着)改善运行时间。


可能使速度更快的一件事是使用np.zeros而不是np.full

K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)

并且不要创建M,因为您将不使用它。


由于您已经使用过line-profiler,所以我决定看一下,得到了这个结果:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1      19137.0  19137.0      0.0      K = np.full((n+1,W+1),0)
     5         1   19408592.0 19408592.0     28.1      N = np.full((n+1,W+1,W+1),0)
     6                                           
     7      1002       6412.0      6.4      0.0      for i in range(n+1):
     8   1003002    4186311.0      4.2      6.1          for w in range(W+1):
     9   1002001    4644031.0      4.6      6.7              if i==0 or w==0:
    10      2001      19663.0      9.8      0.0                  K[i][w] = 0
    11   1000000    5474080.0      5.5      7.9              elif wt[i-1] <= w:
    12    498365    9616406.0     19.3     13.9                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     52596     902030.0     17.2      1.3                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     52596     578740.0     11.0      0.8                      c = N[i-1][w-wt[i-1]]
    15     52596     295980.0      5.6      0.4                      c[i] = i
    16     52596    1239792.0     23.6      1.8                      N[i][w] = c
    17                                                           else:
    18    445769    5100917.0     11.4      7.4                      K[i][w] = K[i-1][w]
    19    445769   11677683.0     26.2     16.9                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    501635    5801328.0     11.6      8.4                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         14.0     14.0      0.0      return N[n][W]

这表明瓶颈是np.fullN[i][w] = N[i-1][w]if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])。 Numba不会改进前两个代码,因为它们已经使用了高度优化的NumPy代码,而numba可能会更慢一些。 Numba可能可以改善if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]),但可能不会引起注意。

如果np.fullnp.zeros替换,则配置文件会稍有变化:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1        747.0    747.0      0.0      K = np.zeros((n+1, W+1),dtype=int)
     5         1     109592.0 109592.0      0.2      N = np.zeros((n+1, W+1, W+1),dtype=int)
     6                                           
     7      1002       4230.0      4.2      0.0      for i in range(n+1):
     8   1003002    4414071.0      4.4      7.0          for w in range(W+1):
     9   1002001    4836807.0      4.8      7.7              if i==0 or w==0:
    10      2001      22282.0     11.1      0.0                  K[i][w] = 0
    11   1000000    5646859.0      5.6      8.9              elif wt[i-1] <= w:
    12    521222   10389581.0     19.9     16.5                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     47579     784563.0     16.5      1.2                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     47579     509056.0     10.7      0.8                      c = N[i-1][w-wt[i-1]]
    15     47579     362796.0      7.6      0.6                      c[i] = i
    16     47579    1975916.0     41.5      3.1                      N[i][w] = c
    17                                                           else:
    18    473643    5579823.0     11.8      8.8                      K[i][w] = K[i-1][w]
    19    473643   22805846.0     48.1     36.1                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    478778    5664271.0     11.8      9.0                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         10.0     10.0      0.0      return N[n][W]

但是主要瓶颈仍然是N[i][w] = N[i-1][w],使用numba可能比使用纯NumPy慢。因此,再次使用numba对代码的其他一些部分所做的改进可能不会引起注意(再次)。


对于第一个配置文件,我使用了此版本的代码(第二个配置文件只是将np.full更改为np.zeros):

import numpy as np

def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N = np.full((n+1,W+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)

%lprun -f knapSack knapSack(W, wt, val, n)

0
投票

这里是新功能:

 @njit
    def knapSack(W, wt, val, n):

        K = np.zeros((n + 1, W + 1),dtype=np.int32)
        # In fact we must only save the previous combinations and the current, 
        # not all :) So N is considerably reduce
        N = np.zeros((2, W + 1, W + 1),dtype=np.int32)

        for i in range(n + 1):
            for w in range(W + 1):
                if i == 0 or w == 0:
                    K[i][w] = 0
                elif wt[i - 1] <= w:
                    if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
                        K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
                        N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
                        N[i%2][w][i] = i
                    else:
                        K[i][w] = K[i - 1][w]
                        N[i%2][w] = N[(i - 1)%2][w]
                else:
                    K[i][w] = K[i - 1][w]
        N[(n)%2][W][0] = K[n][W]
        return N[(n)%2][W]

非常感谢您 MSeifert !!

© www.soinside.com 2019 - 2024. All rights reserved.