使用内部带有for循环的jitted函数

问题描述 投票:-1回答:1

我有以下代码

@njit
    def ss(w,g,sm):
        kb = ( (α/(r*(1+τk)))**((1-γ)/(1-γ - α)) )* \
        ( (γ/(w*(1+τn)))**(γ/(1-γ-α)) )* \
        (smat*(1-τo))**(1/(1-γ-α)) ## ss capital

        nb = (1+τk)*r*γ/((1+τn)*w*α)*kb ## ss labor

        πb = (1-τo)*sm*(kb**α)*(nb**γ)- (1+τn)*w*nb-(1+τk)*r*kb-cf #ss profit
        W = πb/(1-0.0196) ## error in the code
        for i in range(ns):
            for j in range(nτ):
                xb[i,j] = 1 if W[i,j]>=0 else xb[i,j]
        we = sum(W*g*xb) - ce
        return we

据我所知它应该可以工作,但是我不断收到以下错误

TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (readonly array(float64, 2d, C), tuple(int64 x 2), float64)
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:

我仍在想尽办法解决python。我知道这是我正在使用njit的事实,但是到底是什么原因引起的呢?如果我删除了for循环,它可以正常工作,但是我想知道for循环中的原因是什么?

可复制的示例

import numpy as np
from numba import njit
class rep:
    def __init__(self, A = 1, B=2, C = 3, D = 4):
        self.A, self.B,self.C, self.D = A,B,C,D


def op(cls):
    A,B,C,D = cls.A, cls.B,cls.C, cls.D
    xb = np.zeros([10,10])
    @njit
    def rep1(w,g,sm):
        kb = A*3 + B*2
        nb = C*3 - D*w
        pib = sm*kb - nb
        W = pib/4
        for i in range(10):
            for j in range(10):
                xb[i,j] = 1 if W[i,j]>=0 else xb[i,j]
        we = np.sum(w*xb*g)
        return we
    return rep1


g = np.zeros([10,10])
sm = np.ones([10,10])
w = 1
r = rep()
rep1 = op(r)
print(rep1(w,g,sm))
python loops for-loop jit numba
1个回答
0
投票

这似乎与注释中提到的@ max9111有关-xb数组是在jitted函数外部创建的,因此无法在内部进行修改。我通过将xb移到op之外并将其作为参数传递给rep1来稍微更改了代码,可以成功执行它:

import numpy as np
from numba import njit
class rep:
    def __init__(self, A = 1, B=2, C = 3, D = 4):
        self.A, self.B,self.C, self.D = A,B,C,D


def op(cls):
    A,B,C,D = cls.A, cls.B,cls.C, cls.D
    @njit
    def rep1(w,g,sm,xb):
        kb = A*3 + B*2
        nb = C*3 - D*w
        pib = sm*kb - nb
        W = pib/4
        for i in range(10):
            for j in range(10):
                xb[i,j] = 1 if W[i,j]>=0 else xb[i,j]
        we = np.sum(w*xb*g)
        return we
    return rep1


g = np.zeros([10,10])
sm = np.ones([10,10])
xb = np.zeros([10,10])
w = 1
r = rep()
rep1 = op(r)
print(rep1(w,g,sm,xb))
© www.soinside.com 2019 - 2024. All rights reserved.