Optimizing how an index is assigned to groups of split data in Polars

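Data logic:

I have a time series that needs to be split into chunks.

Say it needs to be split into 3 chunks for this post. The data I am using is stock quote data with daily prices. If the time series spans 3 months and the "split" range is 1 month, there should be 3 chunks of data, each month labelled with an increasing integer. So the time series should have 3 sections, all in one dataframe. There should be a column named range_index that starts at 0 and runs up to 2. For example, if the data covers January through March, each quote should be labelled 0, 1, or 2: 0 for January, 1 for February, and 2 for March.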

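I would like this done for every symbol in the dataframe. The start_date of each symbol may differ, so the function should be robust to that and assign the correct range_index values based on each symbol's stock data.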
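
To make the expected labels concrete, here is a minimal plain-Python sketch of the rule described above. It is only a slow reference for checking results, not the Polars solution. The helper name label_month_chunks, the AAPL ticker, and the dates are made up for illustration, and it assumes that a month with no rows for a symbol does not consume an index.

```python
from datetime import date


def label_month_chunks(rows):
    """Give each (symbol, date) row a zero-based, per-symbol month index."""
    seen = {}  # symbol -> {(year, month): range_index}
    labelled = []
    for symbol, day in sorted(rows):
        months = seen.setdefault(symbol, {})
        key = (day.year, day.month)
        if key not in months:
            # First row of a new month for this symbol: take the next integer.
            months[key] = len(months)
        labelled.append((symbol, day, months[key]))
    return labelled


# Hypothetical sample: the two symbols start in different months.
rows = [
    ("AAPL", date(2023, 1, 3)),
    ("AAPL", date(2023, 2, 1)),
    ("AAPL", date(2023, 3, 1)),
    ("PCT", date(2023, 2, 15)),
    ("PCT", date(2023, 3, 15)),
]

for symbol, day, range_index in label_month_chunks(rows):
    print(symbol, day, range_index)
```

Each symbol starts again at 0 in its own first month, which is the behaviour the per-symbol start_date requirement asks for.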

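What I've done:

I have built a function using Polars logic that adds this column to the dataframe, but I think there may be faster ways to do it. When I add a few symbols with a few years of data, it slows down to about 3 seconds per call.

I would love any advice on how to speed up the function, or even a novel approach. I'm aware that row-based operations are slower in Polars than columnar ones. If any Polars nerds out there spot a glaring issue... please help!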

import polars as pl


def add_range_index(
    data: pl.LazyFrame | pl.DataFrame, range_str: str
) -> pl.LazyFrame | pl.DataFrame:
    """
    Context: Toolbox || Category: Helpers || Sub-Category: Mandelbrot Channel Helpers || **Command: add_range_index**.

    This function is used to add a column to the dataframe that contains the
    range grouping for the entire time series.


    This function is used in `log_mean()`
    """  # noqa: W505
    range_str = _range_format(range_str)

    if "date" in data.columns:
        group_by_args = {
            "every": range_str,
            "closed": "left",
            "include_boundaries": True,
        }
        if "symbol" in data.columns:
            group_by_args["by"] = "symbol"

        grouped_data = (
            data.lazy()
            .set_sorted("date")
            .group_by_dynamic("date", **group_by_args)
            .agg(
                pl.col("adj_close").count().alias("n_obs")
            )  # count the 'adj_close' observations in each window
        )
    # Number the windows 0, 1, 2, ... within each symbol
    range_row = grouped_data.with_columns(
        pl.arange(0, pl.len()).over("symbol").alias("range_index")
    )
    ## WIP:
    # Extract the number of ranges the time series has

    # Initialize a new column to store the range index
    data = data.with_columns(pl.lit(None).alias("range_index"))

    # Loop through each range and add the range index to the original dataframe
    for row in range_row.collect().to_dicts():
        symbol = row["symbol"]
        start_date = row["_lower_boundary"]
        end_date = row["_upper_boundary"]
        range_index = row["range_index"]

        # Apply the conditional logic to each group defined by the 'symbol' column
        data = data.with_columns(
            pl.when(
                (pl.col("date") >= start_date)
                & (pl.col("date") < end_date)
                & (pl.col("symbol") == symbol)
            )
            .then(range_index)
            .otherwise(pl.col("range_index"))
            .over("symbol")  # Apply the logic over each 'symbol' group
            .alias("range_index")
        )

    return data


def _range_format(range_str: str) -> str:
    """
    Context: Toolbox || Category: Technical || Sub-Category: Mandelbrot Channel Helpers || **Command: _range_format**.

    This function formats a range string into a standard format.
    The return value is to be passed to `_range_days()`.

    Parameters
    ----------
    range_str : str
        The range string to format. It should contain a number followed by a
        range part. The range part can be 'day', 'week', 'month', 'quarter', or
        'year'. The range part can be in singular or plural form and can be
        abbreviated. For example, '2 weeks', '2week', '2wks', '2wk', '2w' are
        all valid.

    Returns
    -------
    str
        The formatted range string. The number is followed by an abbreviation of
        the range part ('d' for day, 'w' for week, 'mo' for month, 'q' for
        quarter, 'y' for year). For example, '2 weeks' is formatted as '2w'.

    Raises
    ------
    HumblDataError
        If an invalid range part is provided.

    Notes
    -----
    This function is used in `log_mean()`
    """  # noqa: W505
    # Separate the number and range part
    num = "".join(filter(str.isdigit, range_str))

    # Find the first alphabetic character in the range string
    range_part = next((char for char in range_str if char.isalpha()), None)

    # Check if the range part is a valid abbreviation
    if range_part not in {"d", "w", "m", "y", "q"}:
        msg = f"`{range_str}` could not be formatted; needs to include d, w, m, y, q"
        raise HumblDataError(msg)

    # If the range part is "m", replace it with "mo" to represent "month"
    if range_part == "m":
        range_part = "mo"

    # Return the formatted range string
    return num + range_part
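
As a side note on the helper above: it keeps every digit found anywhere in the string and the first letter found anywhere, so an input like "1 month 2" would come out as "12mo". The following is a hedged sketch of a stricter variant that accepts the same examples from the docstring. The name normalize_range and the use of ValueError are assumptions for illustration, not part of the original project.

```python
import re

# Map the first letter of the unit to the duration suffix; "m" becomes "mo"
# because a bare "m" means minutes in Polars duration strings.
_UNIT_SUFFIX = {"d": "d", "w": "w", "m": "mo", "q": "q", "y": "y"}


def normalize_range(range_str: str) -> str:
    """Turn inputs like '2 weeks' or '1 month' into '2w' or '1mo'."""
    match = re.fullmatch(r"\s*(\d+)\s*([A-Za-z]+)\s*", range_str)
    if match is None:
        msg = f"`{range_str}` must be a number followed by a unit"
        raise ValueError(msg)
    number, unit = match.groups()
    suffix = _UNIT_SUFFIX.get(unit[0].lower())
    if suffix is None:
        msg = f"`{range_str}` needs a unit starting with d, w, m, q or y"
        raise ValueError(msg)
    return number + suffix


for example in ["2 weeks", "2wks", "2w", "1 month", "3q"]:
    print(example, "->", normalize_range(example))
```

Like the original, this only looks at the first letter of the unit, so "2 minutes" would still be read as months.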

Expected data table:

The same applies to the PCT symbol.

1 Answer

You are already on the right track with pl.DataFrame.group_by_dynamic. The whole function can be simplified as follows.

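  1. Group the data by symbol and by range window (e.g. 1-month periods).
  2. Aggregate the data within each symbol/window group.
  3. Add the range_index column to each group.
  4. Drop the date column (which now holds the start date of each period), replace it with the range_dates column (which holds all dates within the period), and explode the new date column so that there is again one row per original date.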
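(
    df
    # Flag the columns as already sorted (this does not sort the data).
    .set_sorted("symbol", "date")
    # Steps 1-2: one row per symbol and 1-month window, with the rows of each
    # window collected into lists. `by` was renamed `group_by` in later Polars
    # releases.
    .group_by_dynamic(index_column="date", every="1mo", by="symbol")
    .agg(
        pl.col("date").alias("range_dates"),
        pl.exclude("date", "symbol"),
    )
    # Step 3: number the windows 0, 1, 2, ... within each symbol.
    .with_columns(
        pl.int_range(0, pl.len()).alias("range_index").over("symbol")
    )
    # Step 4: swap the window start for the original dates, then unnest.
    .drop("date")
    .rename({"range_dates": "date"})
    .explode(pl.exclude("symbol", "range_index"))
)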