如何在 Numba 的 np.repeat() 中重新创建 axis = 2

问题描述 投票:0回答:1

我正在尝试将我的代码转换为 Numba 友好的实现,但是我不断遇到 axis 参数的错误(因为它不受支持)。具体来说,我需要在 axis=2 中使用 np.repeat() 函数

在 numpy 中我的代码是:

original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims((5)*original, axis=2), no_repeats, axis=2)

我如何以 Numba 友好的方式重写它?

我尝试过使用np.dstack:

expanded_original = np.expand_dims((5)*original, axis=2)
big_original = np.dstack([expandedGradientMatrix]*no_repeats)

但是列表当然不是受支持的数据类型。我怎样才能以最有效的方式解决这个问题?

python numpy performance optimization numba
1个回答
0
投票

我不知道你到底想做什么,但我猜你想在

big_original
numba 编译函数中重现
@njit
数组。对吗?

我是这样:

@njit
def repeat_original(original, no_repeats):
    big_original = np.zeros((*original.shape, no_repeats))
    for i in range(big_original.shape[-1]):
        big_original[...,i] = (5)*original
    return big_original
repeat_original(original, no_repeats)

如果这不是您期望的答案,请尝试更好地说明您的问题(例如

expandedGradientMatrix
是什么)和您的预期输出。

© www.soinside.com 2019 - 2024. All rights reserved.