我尝试将dask.array.apply_along_axis
用于2D数组。但是,我的数组是一个dask数组,它总是抛出一个异常,如下所示:
Traceback (most recent call last):
File "D:/test/apply_along_axis_test.py", line 22, in <module>
b = da.apply_along_axis(lambda a: a[index_array], 1, source_array)
File "D:\Program Files\Python3\lib\site-packages\dask\array\routines.py", line 383, in apply_along_axis
test_result = np.array(func1d(test_data, *args, **kwargs))
File "D:/test/apply_along_axis_test.py", line 22, in <lambda>
b = da.apply_along_axis(lambda a: a[index_array], 1, source_array)
IndexError: index 1 is out of bounds for axis 0 with size 1
但是,当我将此方法应用于numpy.array
时。它可以成功运行。
示例代码如下:
source_array = np.random.randint(0, 10, (2, 4))
index_array = np.asarray([[0, 0], [1, 0], [2, 1], [3, 2]])
b = np.apply_along_axis(lambda a: a[index_array], 1, source_array)
print(b)
source_array = da.from_array(source_array)
b = da.apply_along_axis(lambda a: a[index_array], 1, source_array)
我可以成功打印b。但是,代码的最后一行将引发异常。我认为也许我应该使用map_partitions
这样的地图方法。但是,我在dask.array
中找不到这样的任何方法。
我认为应该通过定义shape
和dtype
来解决。您可以手动执行此操作,也可以使用make_meta
推断其内容:
In [55]: from dask.dataframe.utils import make_meta
In [56]: da.apply_along_axis(lambda a: a[index_array], 1, source_array,
...: shape=make_meta(source_array).shape,
...: dtype=make_meta(source_array).dtype).compute()
Out[56]:
array([[[2, 2],
[1, 2],
[1, 1],
[6, 1]],
[[1, 1],
[6, 1],
[9, 6],
[3, 9]]])
您也不是第一个遇到此问题的人:https://github.com/dask/dask/issues/3727