如何使用Dask并行迭代和更新numpy数组

问题描述 投票:0回答:1

我有一个非常大的距离矩阵,我需要迭代每个值并在条件为真时更新距离。

这是我的 Pandas/Numpy 代码块:

dist_mat = pd.read_csv()
date_list = metadata['sample_collection_date'].values
numpy_arr = dist_mat.values
columns = dist_mat.columns.tolist()
col_index = 0
with alive_bar(dist_mat.shape[0]) as bar:
   for i in range(dist_mat.shape[0]):
        numpy_arr[i][i] += 0.1
        for j in range(col_index):
            if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= 14:
                    numpy_arr[i][j] += 0.1
                    numpy_arr[j][i] += 0.1
        col_index += 1
        bar()

我尝试过使用 Dask,但它并不比我使用 Pandas/Numpy 快。我想知道什么是有助于并行处理此代码块的正确方法。

dist_mat = dd.read_csv(args.dist_file, sep='\t', skiprows=2, sample=10000000, assume_missing=True).set_index('#Sources')

date_list = metadata['sample_collection_date'].values
np_array = dist_mat.to_dask_array(lengths=True)
columns = dist_mat.columns.tolist()
col_index = 0
with alive_bar(dist_mat.shape[0].compute()) as bar:
    for i in range(dist_mat.shape[0].compute()):
        numpy_arr[i][i] += 0.1
        for j in range(col_index):
            if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= 14:
                    numpy_arr[i][j] += 0.1
                    numpy_arr[j][i] += 0.1
        col_index += 1
        bar()
python pandas dataframe numpy dask
1个回答
0
投票

问题似乎是您当前的 Dask 代码,您在循环中调用

.compute()
。这对于计算时间和内存来说可能非常昂贵。

更有效的方法可能是使用

Numba
库,它可以 JIT 编译 python 代码,并可以利用多个内核来执行某些类型的操作 - 特别是
ufuncs

这是一个例子

import numba
import numpy as np
import pandas as pd

@numba.njit(parallel=True)
def update_matrix(dist_mat, date_list, threshold=14, increment=0.1):
    N = dist_mat.shape[0]
    for i in numba.prange(N):
        dist_mat[i, i] += increment
        for j in range(i):
            if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= threshold:
                dist_mat[i, j] += increment
                dist_mat[j, i] += increment
    return dist_mat

# Load your data
metadata = pd.read_csv('metadata.csv')
dist_mat = pd.read_csv('dist_mat.csv')

date_list = metadata['sample_collection_date'].values
dist_mat = dist_mat.values

# Update matrix
dist_mat = update_matrix(dist_mat, date_list)
© www.soinside.com 2019 - 2024. All rights reserved.