对于size =(M,N)的大于内存的dask数组:如何从chunk =(1,N)重新chunk到chunk =(M,1)?

问题描述 投票:2回答:2

为了例如沿着整个轴应用在Numpy / Numba中编码的IIR-Filter,我需要用size=(M, N)chunks=(m0, n0) dask-array从chunks=(m1, N)重新组合到m1 < m0

由于Dask避免重复任务,在rechunk-split / rechunk-merge期间,它将在内存中具有值(m0, N)(x 2?)的数据。有没有一种优化图形的方法来避免这种行为?

我知道在哪里可以找到关于手动优化Dask图的信息。但有没有办法调整调度策略以允许重复任务或(自动)重新排列图形以便在此重新安排期间最小化内存使用?

这是一个最小的例子(对于chunks=(M, 1)chunks=(1, N)的极端情况):

from dask import array as da
from dask.distributed import Client

# limit memory to 4 GB
client = Client(memory_limit=4e9)

# Create 80 GB random array with chunks=(M, 1)
arr = da.random.uniform(-1, 1, size=(1e5, 1e5), chunks=(1e5, 1))

# Compute mean (This works!)
arr.mean().compute()

# Rechunk to chunks=(1, N)
arr = arr.rechunk((1, 1e5))

# Compute mean (This hits memory limit!)
arr.mean().compute()
dask dask-distributed
2个回答
2
投票

不幸的是,在最糟糕的情况下,您需要在获得单个输出块之前计算每个输入块。

Dask的重新组合操作是不错的,并且它们会在过渡期间将内容重新组合成中间大小的块,因此这可能会在不完整的内存中运行,但是你肯定会将内容写入磁盘。

简而言之,原则上你不应该做任何额外的事情。理论上,Dask的重新编程算法应该处理这个问题。如果你想,你可以使用threshold=block_size_limit=关键字来重新安排。


0
投票

block_size_limit=关键字导致了一种解决方案。

(下面,我使用一个较小的阵列,因为我没有留下80GB的磁盘溢出。)

from dask import array as da
from dask.distributed import Client

# limit memory to 1 GB
client = Client(n_workers=1, threads_per_worker=1, memory_limit=1e9)

# Create 3.2 GB array
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 2000 nodes

# Compute
print(arr.mean().compute())  # Takes 11.9 seconds. Doesn't spill.

# re-create array and rechunk with block_size_limit=1e3
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4), block_size_limit=1e3)

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 32539 nodes

# Compute
print(arr.mean().compute())  # Takes 140 seconds, spills ~5GB to disk.

# re-create array and rechunk with default kwargs
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4))

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 9206 nodes

# Compute
print(arr.mean().compute())  # Worker dies at 95% memory use
© www.soinside.com 2019 - 2024. All rights reserved.