我正在尝试将我的代码转换为 Numba 友好的实现,但是我不断遇到 axis 参数的错误(因为它不受支持)。具体来说,我需要在 axis=2 中使用 np.repeat() 函数
在 numpy 中我的代码是:
original = np.random.rand(1000,1)
no_repeats = 10
big_original = np.repeat(np.expand_dims((5)*original, axis=2), no_repeats, axis=2)
我如何以 Numba 友好的方式重写它?
我尝试过使用np.dstack:
expanded_original = np.expand_dims((5)*original, axis=2)
big_original = np.dstack([expandedGradientMatrix]*no_repeats)
但是列表当然不是受支持的数据类型。我怎样才能以最有效的方式解决这个问题?
我不知道你到底想做什么,但我猜你想在
big_original
numba 编译函数中重现 @njit
数组。对吗?
我是这样:
@njit
def repeat_original(original, no_repeats):
big_original = np.zeros((*original.shape, no_repeats))
for i in range(big_original.shape[-1]):
big_original[...,i] = (5)*original
return big_original
repeat_original(original, no_repeats)
如果这不是您期望的答案,请尝试更好地说明您的问题(例如
expandedGradientMatrix
是什么)和您的预期输出。