不支持Numba多维索引

问题描述 投票:0回答:0

我正在运行代码来模拟经济模型。该代码使用 numpy 并且执行没有错误。我试图通过在特定函数中包含 numba 的“njit()”装饰器来加快性能。令我惊讶的是,代码现在产生了一个不应该出现的错误。当我应用函数 c[c<2000] = 2000 a multidimensional index error appears. The wierd part of this is that I perform this operation in two different parts of the code, and in the first one it works without any error, but in the second I get a multidimensional index problem. The code is the following. The error comes from the one before the last one line of the code.

@njit()
def get_utility(x1,x1_new,x2,b,b1,e,j,period,param_g):
    if j[1] == 0:  # the individual does not study and no max is needed.
        w = wage(x1_new,x2)*(j[2]/2)   # adjust wages for labor supply decision
        w_vis = np.repeat(w,np.shape(b)[0]*np.shape(e)[0])
        b_vis = numba_tile_new(b,np.shape(w)[0]*np.shape(e)[0])
        e_vis = np.repeat(e,np.shape(w)[0]*np.shape(b)[0])
        
        c = (w_vis-(1+r)*b_vis+e_vis+repayment(b_vis))
        c[c<2000]  = 2000
        u = get_power_utility(c)
        # Create indicator for which choice
        
        pg = get_param_g(j,param_g).astype("float64")
            
        # Include a constant to g()
        
        # Create x1 with polynomials. 
        
        x1_poli = get_x1_poli(x1)
        
        # get g function
        g = x1_poli@pg
        
        payoff =  u+ 0.1*g
        
        return payoff[...,np.newaxis]

    else: 
        h = fin_help(x1_new,x2,j,period)
        h_vis = np.repeat(h,np.shape(b)[0]*np.shape(e)[0])
        
        w = wage(x1_new,x2)*(j[2]/2)   # adjust wages for labor supply decision
        w_vis = np.repeat(w,np.shape(b)[0]*np.shape(e)[0])
        
        b_vis = numba_tile_new(b,np.shape(h)[0]*np.shape(e)[0])
        e_vis = np.repeat(e,np.shape(h)[0]*np.shape(b)[0])
        c = (h_vis-(1+r)*b_vis-tuition(j)+e_vis+ w_vis)
        c =c[...,np.newaxis] + b1
        
        c[c<2000]  = 2000
        
        return c

产生的错误是:

`TypingError:找不到签名的函数 Function() 的实现:

setitem(数组(float64, 2d, C), 数组(bool, 2d, C), Literalint)

有 16 个候选实现: - 其中 14 个不匹配,原因是: 函数“setitem”重载:文件::行不适用。 带参数:'(array(float64, 2d, C), array(bool, 2d, C), int64)': 没有匹配。 - 其中 2 个不匹配,原因是: 函数“SetItemBuffer.generic”中的重载:文件:numb

python numpy multidimensional-array numba indices
© www.soinside.com 2019 - 2024. All rights reserved.