Pyspark-一组两个日期列的UDAF函数,UDAF用于计算实际值和预测值之间的RMSE

问题描述 投票:1回答:1

在接下来的几年中,我将数据保存在pyspark数据框中。 week_start_dt是我开始进行预测的时间。而start_month是前12个月。

+--------------------+------------------+----------------------+----------------+
|     start_month    |     week_start_dt|           predictions|       actuals  |
+--------------------+------------------+----------------------+----------------+
|             2019-01|        2019-11-11|                    12|              11|
|             2018-12|        2019-11-11|                    13|              11|
|             2019-08|        2019-11-11|                     9|              11|
|             2019-11|        2019-11-11|                    12|              11|
|             2019-11|        2019-11-11|                  1970|            1440|
|             2019-11|        2019-11-11|                   478|             501|
+--------------------+------------------+----------------------+----------------+

我想在groupbystart_month上使用week_start_dt来计算RMSE。我认为这将需要一个用户定义的聚合函数。在大熊猫中与此类似:Python Dataframe: Calculating R^2 and RMSE Using Groupby on One Column

我使用以下代码来汇总分组依据的实际计数和预测计数。

df_startmonth_week = actuals_compare.groupby('start_month', 'week_start_dt').agg(f.sum('predictions'), f.sum('actuals'))

我在汇总步骤中如何更改以计算预测值与实际值之间的RMSE?我需要UDF来做到这一点吗?

这里是我在excel中制定的最终目标的示例

| week_start_dt | start_month | RMSE |
|---------------|-------------|------|
| 20-01-2020    | 2019-02     | 2345 |
| 20-01-2020    | 2019-03     | 2343 |
| 20-01-2020    | 2019-04     | 2341 |
| 20-01-2020    | 2019-05     | 2100 |
| 20-01-2020    | 2019-06     | 1234 |
apache-spark pyspark apache-spark-sql pyspark-sql
1个回答
1
投票

我看不到与problem here的区别,所以我将解决方案调整为略有不同的变量名称:

import pyspark.sql.functions as psf

def compute_RMSE(expected_col, actual_col):

  rmse = old_df.withColumn("squarederror",
                           psf.pow(psf.col(actual_col) - psf.col(expected_col),
                                   psf.lit(2)
                           ))
  .groupby('start_month', 'week_start_dt')
  .agg(psf.avg(psf.col("squarederror")).alias("mse"))
  .withColumn("rmse", psf.sqrt(psf.col("mse")))

  return(rmse)


compute_RMSE("predictions", "actuals")

告诉我,如果我错过了这个问题的细微差别

© www.soinside.com 2019 - 2024. All rights reserved.