这是我的MWE:
from numba import njit
import numpy as np
@njit
def solve(n):
count = np.zeros(n + 1, dtype=int)
res = np.array([0], dtype=int)
def search(sz=0, max_val=1, single=0, previous=None):
nonlocal res
if sz == 4 * n:
res[0] += 1
return
if single and count[0] < 2 * n:
count[0] += 1
search(sz + 1, max_val, single)
count[0] -= 1
for i in range(1, max_val + 1):
if i != previous and count[i] < 2:
count[i] += 1
search(sz + 1, max_val + (i == max_val and max_val < n), single + (count[i] == 1) - (count[i] == 2), i)
count[i] -= 1
search()
return res[0]
for i in range(1, 6):
print(solve(i))
这给出:
NotImplementedError: Failed in nopython mode pipeline (step: analyzing bytecode)
Unsupported use of op_LOAD_CLOSURE encountered
让它与 numba 一起工作的正确方法是什么?如果删除 @njit 行,代码会正确运行,但运行速度会很慢。
我将
@njit
放入内部函数:
import numpy as np
from numba import njit
def solve(n):
@njit
def search(count, res, sz=0, max_val=1, single=0, previous=None):
if sz == 4 * n:
res[0] += 1
return
if single and count[0] < 2 * n:
count[0] += 1
search(count, res, sz + 1, max_val, single)
count[0] -= 1
for i in range(1, max_val + 1):
if i != previous and count[i] < 2:
count[i] += 1
search(
count,
res,
sz + 1,
max_val + (i == max_val and max_val < n),
single + (count[i] == 1) - (count[i] == 2),
i,
)
count[i] -= 1
count = np.zeros(n + 1, dtype=int)
res = np.array([0], dtype=int)
search(count, res)
return res[0]
for i in range(1, 6):
print(solve(i))