在 numba 中连接 python 元组

问题描述 投票:0回答:2

我正在寻找用从一些元组中获取的数字填充零数组,就这么简单。

通常,即使元组的长度不同(这就是这里的重点),这也不是问题。但它似乎无法编译,我无法找到解决方案。

from numba import jit    

def cant_jit(ls):

    # Array total lenth
    tl = 6
    # Type
    typ = np.int64

    # Array to modify and return
    start = np.zeros((len(ls), tl), dtype=typ)

    for i in range(len(ls)):

        a = np.array((ls[i]), dtype=typ)
        z = np.zeros((tl - len(ls[i]),), dtype=typ)
        c = np.concatenate((a, z))
        start[i] = c

    return start

# Uneven tuples would be no problem in vanilla
cant_jit(((2, 4), (6, 8, 4)))


jt = jit(cant_jit)    
# working fine
jt(((2, 4), (6, 8)))
# non working
jt(((2, 4), (6, 8, 4)))

在错误范围内。

getitem(元组(UniTuple(int64 x 3),UniTuple(int64 x 2)),int64) 有 22 个候选实现: - 其中 22 个不匹配,原因是: 函数“getitem”的重载:文件::行 N/A。 带有参数:'(Tuple(UniTuple(int64 x 3), UniTuple(int64 x 2)), int64)': 没有匹配。

我在这里尝试了一些东西但没有成功。有人知道解决这个问题的方法,以便可以编译该函数并仍然执行它的操作吗?

python numpy concatenation numba jit
2个回答
1
投票

据我所知,这是不可能的,numba 文档告诉我们,除非您使用 forceobj=True,否则不等长的嵌套元组是不合法的。您甚至无法解压 *args,这令人沮丧。您将始终收到该警告/错误:

只需像这样将参数添加到 jit() 中:

from numba import jit    
import numpy as np

def cant_jit(ls):

    # Array total lenth
    tl = 6
    # Type
    typ = np.int64

    # Array to modify and return
    start = np.zeros((len(ls), tl), dtype=typ)

    for i in range(len(ls)):

        a = np.array((ls[i]), dtype=typ)
        z = np.zeros((tl - len(ls[i]),), dtype=typ)
        c = np.concatenate((a, z))
        start[i] = c

    return start

# Uneven tuples would be no problem in vanilla
cant_jit(((2, 4), (6, 8, 4)))


jt = jit(cant_jit, forceobj=True)    
# working fine
jt(((2, 4), (6, 8)))
# now working
jt(((2, 4), (6, 8, 4)))

这行得通,但有点毫无意义,你也可以使用核心 python。


0
投票

我想知道

numba
是否会更喜欢这个非 numpy 版本:

def foo1(ls):
    res = []
    for row in ls:
        res.append(row+((0,)*(6-len(ls))))
    return res
© www.soinside.com 2019 - 2024. All rights reserved.