如何创建spark udf用于插入float到INT,以及如何编写比我更好的逻辑

问题描述 投票:1回答:2

下面是我的Spark Dataframe我想做插值并为此编写一个Spark UDF我不知道如何编写更好的逻辑并从上面创建一个UDF

这用于转换Position_float并将其插值为整数,以将Position转换为适当的整数值

def dirty_fill(df, id_col, y_cols):
    from pyspark.sql import types as T
    df = df.withColumn('position_plus', (df.position_float + 0.5).cast(T.IntegerType()))
    df = df.withColumn('position_minus', (df.position_float - 0.5).cast(T.IntegerType()))
    df = df.withColumn('position', df.position_float.cast(T.IntegerType()))
    df1 = df.select([id_col, 'position_plus'] + y_cols).withColumnRenamed('position_plus', 'position')
    df2 = df.select([id_col, 'position_minus'] + y_cols).withColumnRenamed('position_minus', 'position')
    df3 = df.select([id_col, 'position'] + y_cols)
    df123 = df1.union(df2).union(df3).sort([id_col, 'position']).dropDuplicates([id_col, 'position'])
    return df123
y_cols = ['entry_temperature']
finish_mill_entry_filled = dirty_fill(finish_mill_entry, 'finish_mill_id', y_cols)

这是我的数据框样本

| Finishing_mill_id  | Sample  | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529         | 1       | 0.000000       | 1986.0     |
| 2015418529         | 2       | 2.192982       | 1997.0     |
| 2015418529         | 3       | 4.385965       | 2003.0     |
| 2018171498         | 445     | 495.535714     | 1643.0     |
| 2018171498         | 446     | 496.651786     | 1734.0     |
| 2018171498         | 447     | 497.767857     | 1748.0     |
| 2018171498         | 448     | 498.883929     | 1755.0     |

我需要将float插值为整数

我想要的是

| Finishing_mill_id  | Sample  | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529         | 1       | 0              | 1986.0     |
| 2015418529         | 2       | 1              | 1986       |
| 2015418529         | 3       | 2              | 1997.0     |
| 2015418529         | 4       | 3              | 1997       |
| 2015418529         | 5       | 4              | 2003.0     |
| 2018171498         | 445     | 496            | 1643.0     |
| 2018171498         | 446     | 497            | 1734.0     |
| 2018171498         | 447     | 498            | 1748.0     |
| 2018171498         | 448     | 499            | 1755.0     |

我需要一个spark user_defined函数来执行此操作,并且不应该丢失任何数据点,因为我的Position_float在0-500范围内我还需要注意每个点都没有丢失任何点。需要以适当的方式修改插值逻辑

为了说清楚,我说我的位置为0.000 2.19,但我没有数据表,但是当我这样做时我需要有1.00的位置..我需要位置1.00的值,即使数据没有那种线性插值我希望它有所帮助

python dataframe pyspark interpolation
2个回答
1
投票

1.窗口功能

您可以使用窗口函数填充间隙并插值。

让我们从一个示例数据帧开始:

import pyspark.sql.functions as psf
import pyspark.sql.types as pst
from pyspark.sql import Window
import numpy as np

df = spark.createDataFrame(
        [[float(t)/10., float(v)] for t, v in zip(np.random.randint(0, 1000, 20), np.random.randint(100, 200, 20))], 
        schema=pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position', 'value']])) \
    .withColumn('position_round', psf.round('position'))

        +--------+-----+--------------+
        |position|value|position_round|
        +--------+-----+--------------+
        |    68.5|121.0|          69.0|
        |    76.3|126.0|          76.0|
        |    88.3|150.0|          88.0|
        |    59.0|197.0|          59.0|
        |    20.7|119.0|          21.0|
        |     0.1|167.0|           0.0|
        |    20.1|177.0|          20.0|
        |    81.9|199.0|          82.0|
        |    63.6|163.0|          64.0|
        |    32.4|115.0|          32.0|
        |    43.6|130.0|          44.0|
        |    11.9|175.0|          12.0|
        |    68.2|176.0|          68.0|
        |    28.9|184.0|          29.0|
        |    46.3|199.0|          46.0|
        |     9.7|155.0|          10.0|
        |    57.8|163.0|          58.0|
        |    83.6|173.0|          84.0|
        |    16.2|169.0|          16.0|
        |    87.1|127.0|          87.0|
        +--------+-----+--------------+

为了填补空白,我们将创建一系列整数:

start, end = list(df.agg(psf.min('position_round'), psf.max('position_round')).collect()[0])
pos_df = spark.range(start=start, end=end, step=1) \
    .withColumnRenamed('id', 'position_round')

现在我们可以加入两个数据帧:

w1 = Window.orderBy('position_round')
w2 = Window.partitionBy('group').orderBy('position_round')

df_resample = df \
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) \
    .join(pos_df, on='position_round', how='right') \
    .withColumn('group', psf.sum((~psf.isnull('position')).cast('int')).over(w1)) \
    .select(
        '*', 
        (psf.row_number().over(w2) - 1).alias('i'), 
        psf.first(psf.col('next_position') - psf.col('position_round')).over(w2).alias('dx'), 
        psf.first('value').over(w2).alias('value0'), 
        psf.first(psf.col('next_value') - psf.col('value')).over(w2).alias('dy')) \
    .withColumn(
        'value_round', 
        psf.when((psf.col('dx') > 0) | psf.isnull('next_value'), psf.col('value0') + psf.col('i') * psf.col('dy') / psf.col('dx')) \
            .otherwise(psf.col('value')))
  • 第一个窗口函数是存储next_valuenext_position以便以后能够计算我们的dxdy
  • 然后,我们需要使用不同的group id识别每个间隙,以便我们可以插入每个不同线性段的值
  • 最后但并非最不重要的是,我们汇集了我们需要的所有元素: 差距:dx delta值:dy 缺口i中的当前行索引

我们现在可以计算value_round,在value位置插入position_round

        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
        |position_round|position|value|next_position|next_value|group|  i|  dx|value0|   dy|value_round|
        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
        |             0|     0.1|167.0|         10.0|     155.0|    1|  0|10.0| 167.0|-12.0|      167.0|
        |             1|    null| null|         null|      null|    1|  1|10.0| 167.0|-12.0|      165.8|
        |             2|    null| null|         null|      null|    1|  2|10.0| 167.0|-12.0|      164.6|
        |             3|    null| null|         null|      null|    1|  3|10.0| 167.0|-12.0|      163.4|
        |             4|    null| null|         null|      null|    1|  4|10.0| 167.0|-12.0|      162.2|
        |             5|    null| null|         null|      null|    1|  5|10.0| 167.0|-12.0|      161.0|
        |             6|    null| null|         null|      null|    1|  6|10.0| 167.0|-12.0|      159.8|
        |             7|    null| null|         null|      null|    1|  7|10.0| 167.0|-12.0|      158.6|
        |             8|    null| null|         null|      null|    1|  8|10.0| 167.0|-12.0|      157.4|
        |             9|    null| null|         null|      null|    1|  9|10.0| 167.0|-12.0|      156.2|
        |            10|     9.7|155.0|         12.0|     175.0|    2|  0| 2.0| 155.0| 20.0|      155.0|
        |            11|    null| null|         null|      null|    2|  1| 2.0| 155.0| 20.0|      165.0|
        |            12|    11.9|175.0|         16.0|     169.0|    3|  0| 4.0| 175.0| -6.0|      175.0|
        |            13|    null| null|         null|      null|    3|  1| 4.0| 175.0| -6.0|      173.5|
        |            14|    null| null|         null|      null|    3|  2| 4.0| 175.0| -6.0|      172.0|
        |            15|    null| null|         null|      null|    3|  3| 4.0| 175.0| -6.0|      170.5|
        |            16|    16.2|169.0|         20.0|     177.0|    4|  0| 4.0| 169.0|  8.0|      169.0|
        |            17|    null| null|         null|      null|    4|  1| 4.0| 169.0|  8.0|      171.0|
        |            18|    null| null|         null|      null|    4|  2| 4.0| 169.0|  8.0|      173.0|
        |            19|    null| null|         null|      null|    4|  3| 4.0| 169.0|  8.0|      175.0|
        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+

2. UDF

如果您不想使用窗口函数,可以编写一个UDF来在python中进行插值,然后返回一个(位置,值)元组的数组:

def interpolate(pos, next_pos, value, next_value):
    if pos == next_pos or next_value is None:
        return [(pos, value)]
    return [[pos + i, value + i * (next_value - value) / (next_pos - pos)] for i in range(int(next_pos - pos))]
interpolate_udf = psf.udf(interpolate, pst.ArrayType(pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position_round', 'value_round']])))

请注意,元组的类型为StructType,以便更容易将元组“扁平”为列。

w1 = Window.orderBy('position_round')
df_udf = df \
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) \
    .withColumn('tmp', psf.explode(interpolate_udf('position_round', 'next_position', 'value', 'next_value'))) \
    .select('*', 'tmp.*').drop('tmp')

这是我们得到的:

        +--------+-----+--------------+-------------+----------+--------------+----------+
        |position|value|position_round|next_position|next_value|position_round|value_round|
        +--------+-----+--------------+-------------+----------+--------------+----------+
        |     0.1|167.0|           0.0|         10.0|     155.0|           0.0|     167.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           1.0|     165.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           2.0|     164.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           3.0|     163.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           4.0|     162.2|
        |     0.1|167.0|           0.0|         10.0|     155.0|           5.0|     161.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           6.0|     159.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           7.0|     158.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           8.0|     157.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           9.0|     156.2|
        |     9.7|155.0|          10.0|         12.0|     175.0|          10.0|     155.0|
        |     9.7|155.0|          10.0|         12.0|     175.0|          11.0|     165.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          12.0|     175.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          13.0|     173.5|
        |    11.9|175.0|          12.0|         16.0|     169.0|          14.0|     172.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          15.0|     170.5|
        |    16.2|169.0|          16.0|         20.0|     177.0|          16.0|     169.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          17.0|     171.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          18.0|     173.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          19.0|     175.0|
        +--------+-----+--------------+-------------+----------+--------------+----------+

0
投票

只需使用round并输入IntegerType

from pyspark.sql import functions as F
from pyspark.sql import types as T

df = df.withColumn('Position_float', F.round(F.col('Position_float')).cast(T.IntegerType()))
© www.soinside.com 2019 - 2024. All rights reserved.