I have an RDD with data like this: (downloadId: String, date: LocalDate, downloadCount: Int). The date and download-id together are unique, and the download count is for that date.

What I've been trying to figure out is how to get the number of consecutive days (counting backwards from the current date) that a downloadId was in the top 100 of all downloadIds. So if a given download was in the top 100 today, yesterday, and the day before, its streak would be 3.

In SQL, I guess this could be solved with window functions. I've seen a similar question: How to add a running count to rows in a 'streak' of consecutive days

(I'm fairly new to Spark and not sure how to map-reduce an RDD to even start solving a problem like this.)

Some more info: the dates span the last 30 days, and there are roughly 4M unique download-ids per day.
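For concreteness, a minimal sketch of the RDD (the ids and counts are made up):

import java.time.LocalDate
// Each (downloadId, date) pair appears exactly once; the count is for that date
val downloads = sc.parallelize(Seq(
  ("app-a", LocalDate.parse("2017-12-31"), 5000),
  ("app-a", LocalDate.parse("2017-12-30"), 4200),
  ("app-b", LocalDate.parse("2017-12-31"), 3100)
))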
I'd suggest you work with DataFrames, as they are much easier to use than RDDs. Leo's answer is shorter, but I couldn't find where it filters for the top 100 downloads, so I decided to post my answer as well. It doesn't depend on window functions, but it is bounded on the number of days in the past you want to look for streaks. Since you said you only use the last 30 days' data, that shouldn't be a problem.
As a first step, I wrote some code to generate a DF similar to what you described. You don't need to run the first block (if you do, reduce the number of rows unless you have a cluster to try it on; it's heavy on memory). There you can see how the RDD (theData) is transformed into a DF (baseData). You should define a schema for it, as I did.
import java.time.LocalDate
import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
// Assumes a SparkContext (sc) and a SparkSession (sparkSession) are in scope,
// with its implicits imported for the $ column syntax (as in spark-shell)

val maxId = 10000
val numRows = 15000000
val lastDate = LocalDate.of(2017, 12, 31)

// Generates the data. As a convenience for working with DataFrames, I converted the dates to epoch days.
val theData = sc.parallelize(1.to(numRows).map {
  _ => {
    val id = Random.nextInt(maxId)
    val nDownloads = Random.nextInt(id / 1000 + 1)
    Row(id, lastDate.minusDays(Random.nextInt(30)).toEpochDay, nDownloads)
  }
})

// Working with DataFrames is much simpler, so I'll generate a DF named baseData from the RDD
val schema = StructType(
  StructField("downloadId", IntegerType, false) ::
  StructField("date", LongType, false) ::
  StructField("downloadCount", IntegerType, false) :: Nil)

val baseData = sparkSession.sqlContext.createDataFrame(theData, schema)
  .groupBy($"downloadId", $"date")
  .agg(sum($"downloadCount").as("downloadCount"))
  .cache()
Now you have the data you need in a DF named baseData. The next step is to limit it to the top 100 of each day - you should discard the data you won't use before doing any other major transformations.
import org.apache.spark.sql.DataFrame

def filterOnlyTopN(data: DataFrame, n: Int = 100): DataFrame = {
  // For each day in the data, find the cutoff # of downloads needed to make it into the top N
  val getTopNCutoff = udf((downloads: Seq[Long]) => {
    val reverseSortedDownloads = downloads.sortBy(- _)
    if (reverseSortedDownloads.length >= n)
      reverseSortedDownloads.drop(n - 1).head
    else
      reverseSortedDownloads.last
  })
  val topNLimitsByDate = data.groupBy($"date")
    .agg(collect_set($"downloadCount").as("downloads"))
    .select($"date", getTopNCutoff($"downloads").as("cutoff"))
  // And then, throw away the records below the top 100
  data.join(topNLimitsByDate, Seq("date"))
    .filter($"downloadCount" >= $"cutoff")
    .drop("cutoff", "downloadCount")
}
val relevantData = filterOnlyTopN(baseData)
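As a quick sanity check (just a sketch, not part of the solution): because the cutoff is computed over each day's distinct download counts, every date should keep at least n ids after the filter, with ties at the cutoff value possibly keeping a few more:

relevantData.groupBy($"date").count().orderBy($"date").show()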
Now that you have the relevantData DF with only the data you need, you can compute the streaks. I've left the ids with no streak at streak 0; you can filter those out with streaks.filter($"streak" > lit(0)).
def getStreak(df: DataFrame, fromDate: Long): DataFrame = {
  val calcStreak = udf((dateList: Seq[Long]) => {
    if (!dateList.contains(fromDate))
      0
    else {
      val relevantDates = dateList.sortBy(- _) // Order the dates descending
        .dropWhile(_ != fromDate)              // And drop everything until we find the starting day we are interested in
      if (relevantDates.length == 1)           // If there's only one day left, it's a one-day streak
        1
      else                                     // Otherwise, count the streak length (this works if no dates are left, too - but not with only 1 day)
        relevantDates.sliding(2)               // Take the days in pairs
          .takeWhile { twoDays => twoDays(1) == twoDays(0) - 1 } // While each pair is of consecutive days
          .length + 1                          // The streak is the number of consecutive pairs + 1 (the initial day of the streak)
    }
  })
  df.groupBy($"downloadId")
    .agg(collect_list($"date").as("dates"))
    .select($"downloadId", calcStreak($"dates").as("streak"))
}
val streaks = getStreak(relevantData, lastDate.toEpochDay)
streaks.show()
+------------+--------+
| downloadId | streak |
+------------+--------+
| 8086 | 0 |
| 9852 | 0 |
| 7253 | 0 |
| 9376 | 0 |
| 7833 | 0 |
| 9465 | 1 |
| 7880 | 0 |
| 9900 | 1 |
| 7993 | 0 |
| 9427 | 1 |
| 8389 | 1 |
| 8638 | 1 |
| 8592 | 1 |
| 6397 | 0 |
| 7754 | 1 |
| 7982 | 0 |
| 7554 | 0 |
| 6357 | 1 |
| 7340 | 0 |
| 6336 | 0 |
+------------+--------+
And there you have it: the streaks DF with the data you need.
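For example, to keep only the ids with an active streak and list the longest ones first (a small usage sketch):

val activeStreaks = streaks.filter($"streak" > lit(0)).orderBy($"streak".desc)
activeStreaks.show()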
Using an approach similar to the one in the PostgreSQL link you listed, you can apply Window functions in Spark as well. Spark's DataFrame API doesn't have an encoder for java.time.LocalDate, so you'll need to convert it to java.sql.Date.
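A minimal illustration of the encoder issue (assuming Spark 2.x, where java.time.LocalDate has no built-in encoder):

import java.time.LocalDate
// Does not compile: no Encoder found for java.time.LocalDate
// Seq((1, LocalDate.parse("2017-12-13"))).toDF("id", "date")
// Works, since Spark has built-in support for java.sql.Date:
Seq((1, java.sql.Date.valueOf(LocalDate.parse("2017-12-13")))).toDF("id", "date")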
Here are the steps: first, transform the RDD into a DataFrame with a supported date format. Next, create a UDF to compute baseDate, which takes a date and the chronological per-id row number (generated using a Window function) as parameters. Subtracting the row number (in days) from each date maps every run of consecutive days onto the same baseDate, so all rows in the same streak share it. Another Window function then computes a per-id, per-baseDate row number, which is the wanted streak value:
import java.time.LocalDate
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val rdd = sc.parallelize(Seq(
  (1, LocalDate.parse("2017-12-13"), 2),
  (1, LocalDate.parse("2017-12-16"), 1),
  (1, LocalDate.parse("2017-12-17"), 1),
  (1, LocalDate.parse("2017-12-18"), 2),
  (1, LocalDate.parse("2017-12-20"), 1),
  (1, LocalDate.parse("2017-12-21"), 3),
  (2, LocalDate.parse("2017-12-15"), 2),
  (2, LocalDate.parse("2017-12-16"), 1),
  (2, LocalDate.parse("2017-12-19"), 1),
  (2, LocalDate.parse("2017-12-20"), 1),
  (2, LocalDate.parse("2017-12-21"), 2),
  (2, LocalDate.parse("2017-12-23"), 1)
))

val df = rdd.map{ case (id, date, count) => (id, java.sql.Date.valueOf(date), count) }.
  toDF("downloadId", "date", "downloadCount")

// Shift a date back by n days; rows of consecutive dates end up with the same baseDate
def baseDate = udf( (d: java.sql.Date, n: Long) =>
  new java.sql.Date(d.getTime - n * 24 * 60 * 60 * 1000)
)

val dfStreak = df.withColumn("rowNum", row_number.over(
    Window.partitionBy($"downloadId").orderBy($"date")
  )).withColumn(
    "baseDate", baseDate($"date", $"rowNum")
  ).select(
    $"downloadId", $"date", $"downloadCount", row_number.over(
      Window.partitionBy($"downloadId", $"baseDate").orderBy($"date")
    ).as("streak")
  ).orderBy($"downloadId", $"date")

dfStreak.show
dfStreak.show
+----------+----------+-------------+------+
|downloadId| date|downloadCount|streak|
+----------+----------+-------------+------+
| 1|2017-12-13| 2| 1|
| 1|2017-12-16| 1| 1|
| 1|2017-12-17| 1| 2|
| 1|2017-12-18| 2| 3|
| 1|2017-12-20| 1| 1|
| 1|2017-12-21| 3| 2|
| 2|2017-12-15| 2| 1|
| 2|2017-12-16| 1| 2|
| 2|2017-12-19| 1| 1|
| 2|2017-12-20| 1| 2|
| 2|2017-12-21| 2| 3|
| 2|2017-12-23| 1| 1|
+----------+----------+-------------+------+
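Note that streak here is a running count within each island of consecutive dates. If, as in the question, you want each id's streak counting back from a given day, one option (a sketch; the date literal just matches this sample data) is to keep only the rows of that day - ids with no row on that date have a streak of 0:

val asOf = java.sql.Date.valueOf("2017-12-21")
dfStreak.filter($"date" === asOf).select($"downloadId", $"streak").show()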