Polars map_batches UDF with multiprocessing


I want to apply a numba UDF that, for each group in df, produces a vector of the same length as that group:

import numba
import numpy as np
import polars as pl

df = pl.DataFrame(
    {
        "group": ["A", "A", "A", "B", "B", "B"],
        "index": [1, 3, 5, 1, 4, 8],
    }
)

@numba.jit(nopython=True)
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    result = np.zeros(array.shape[0])
    accumulator = 0
    
    for i, value in enumerate(array):
        accumulator += value
        if accumulator >= threshold:
            result[i] = 1
            accumulator = 0
            
    return result

df.with_columns(
    pl.col("index")
    .map_batches(
        lambda x: UDF(x.to_numpy(), 5)
        )
    .over("group")
    .cast(pl.UInt8)
    .alias("udf")
    )

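This was inspired by this post, which introduces a multi-processing approach. In the example above, however, I apply the UDF with the over window function. Is there an efficient way to parallelize this execution?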

Expected output:

shape: (6, 3)
┌───────┬───────┬─────┐
│ group ┆ index ┆ udf │
│ ---   ┆ ---   ┆ --- │
│ str   ┆ i64   ┆ u8  │
╞═══════╪═══════╪═════╡
│ A     ┆ 1     ┆ 0   │
│ A     ┆ 3     ┆ 0   │
│ A     ┆ 5     ┆ 1   │
│ B     ┆ 1     ┆ 0   │
│ B     ┆ 4     ┆ 1   │
│ B     ┆ 8     ┆ 1   │
└───────┴───────┴─────┘
python parallel-processing python-polars
1 answer

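Here is an example of how you can do this with numba, using its parallelization features (prange). Note that it assumes all groups have the same length n and that the rows of each group are contiguous: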

import numpy as np
import polars as pl
from numba import njit, prange


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(array.size // n):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result

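# Data from the question plus a third group "C" (matches the output shown below)
df = pl.DataFrame(
    {
        "group": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
        "index": [1, 3, 5, 1, 4, 8, 1, 1, 4],
    }
)

# n=3 is the fixed group length, 5 is the threshold
df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)
print(df)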

Prints:

shape: (9, 3)
┌───────┬───────┬─────────┐
│ group ┆ index ┆ new_udf │
│ ---   ┆ ---   ┆ ---     │
│ str   ┆ i64   ┆ u8      │
╞═══════╪═══════╪═════════╡
│ A     ┆ 1     ┆ 0       │
│ A     ┆ 3     ┆ 0       │
│ A     ┆ 5     ┆ 1       │
│ B     ┆ 1     ┆ 0       │
│ B     ┆ 4     ┆ 1       │
│ B     ┆ 8     ┆ 1       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 4     ┆ 1       │
└───────┴───────┴─────────┘
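
The function above relies on every group having exactly n rows. When group sizes differ, one possible variation is to pass the start offset of each group instead of a fixed n. The sketch below is not part of the original answer: group_offsets and udf_by_offsets are hypothetical names, and it assumes the rows of each group are stored contiguously (for example after sorting by the group column). The try/except fallback is only there so the snippet still runs, without any parallelism, where numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback: run the same loops in plain Python
    prange = range

    def njit(*args, **kwargs):
        return lambda func: func


def group_offsets(groups):
    """Start index of every contiguous group, plus the total length at the end."""
    # Positions where the group label differs from the previous row
    starts = np.flatnonzero(groups[1:] != groups[:-1]) + 1
    return np.concatenate(([0], starts, [groups.size])).astype(np.int64)


@njit(parallel=True)
def udf_by_offsets(values, offsets, threshold):
    result = np.zeros(values.size, dtype=np.uint8)
    # One prange iteration per group; groups may have different lengths
    for g in prange(offsets.size - 1):
        accumulator = 0
        for j in range(offsets[g], offsets[g + 1]):
            accumulator += values[j]
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0
    return result


# Hypothetical data: groups of length 2, 4 and 1
groups = np.array(["A", "A", "B", "B", "B", "B", "C"])
values = np.array([1, 4, 2, 2, 2, 9, 5], dtype=np.int64)

offsets = group_offsets(groups)
flags = udf_by_offsets(values, offsets, 5)
print(offsets)
print(flags)
```

With a Polars frame, the two inputs could be taken from df["group"].to_numpy() and df["index"].to_numpy().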

Benchmark:

from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N, n):
    assert N % n == 0

    df = pl.DataFrame(
        {
            "group": [f"group_{i}" for i in range(N // n) for _ in range(n)],
            "index": np.random.randint(1, 5, size=N, dtype="uint64"),
        }
    )
    return df


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    result = np.zeros(array.shape[0])
    accumulator = 0

    for i, value in enumerate(array):
        accumulator += value
        if accumulator >= threshold:
            result[i] = 1
            accumulator = 0

    return result


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    result = np.zeros_like(array, dtype="uint8")

    for i in prange(array.size // n):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result


def get_udf_polars(df):
    return df.with_columns(
        pl.col("index")
        .map_batches(lambda x: UDF(x.to_numpy(), 5))
        .over("group")
        .cast(pl.UInt8)
        .alias("udf")
    )


df = get_df(3 * 33_333, 3)  # 99_999 values (~100_000), groups of length 3

df = get_udf_polars(df)

df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)

assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit(
    'df.with_columns(pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5)))',
    number=1,
    globals=globals(),
)

print(t1)
print(t2)

On my machine (AMD 5700X) this prints (times in seconds):

2.7000599699968006
0.00025866299984045327

For 100_000_000 rows with groups of length 3, it takes 0.06319052699836902 seconds (with parallel=False it takes 0.2159650030080229 seconds).
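
The timings above come from single runs (number=1), which can be noisy. A minimal sketch of a more stable measurement, assuming it is acceptable to report the fastest of several runs: best_time and accumulate_flags are hypothetical helpers, and the pure-Python accumulate_flags is only a stand-in for the numba functions so that the snippet runs on its own.

```python
from timeit import repeat


def best_time(func, *args, repeats=5):
    """Return the fastest of `repeats` single-call timings, in seconds."""
    func(*args)  # warm-up call; for a numba function this triggers compilation
    timings = repeat(lambda: func(*args), number=1, repeat=repeats)
    return min(timings)


def accumulate_flags(values, threshold):
    """Stand-in workload: same reset-on-threshold logic as the UDF."""
    result = []
    accumulator = 0
    for value in values:
        accumulator += value
        if accumulator >= threshold:
            result.append(1)
            accumulator = 0
        else:
            result.append(0)
    return result


data = [1, 3, 5, 1, 4, 8] * 10_000
elapsed = best_time(accumulate_flags, data, 5)
print(f"best of 5 runs: {elapsed:.6f} s")
```

The warm-up call keeps JIT compilation time out of the measurement, in the same way the benchmark above calls both functions once before timing them.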
