让我们以 Pandas 数据框为例,它有两列
'date'
和 'price'
,其中 'date'
始终升序,而 'price'
是随机的,即
df = pd.DataFrame({
'date':['01/01/2019', '01/02/2019', '01/03/2019',
'01/04/2019', '01/05/2019', '01/06/2019',
'01/07/2019', '01/08/2019', '01/09/2019',
'01/10/2019'],
'price': [10, 2, 5, 4, 12, 8, 9, 19, 12, 3]
})
目标是再添加两列,
'next_date'
和 'next_price'
,其中
'next_date'
包含第一次出现大于当前价格的价格的日期,'next_price'
包含第一次出现大于当前价格的价格,即
date price next_date next_price
0 01/01/2019 10 01/05/2019 12
1 01/02/2019 2 01/03/2019 5
2 01/03/2019 5 01/05/2019 12
3 01/04/2019 4 01/05/2019 12
4 01/05/2019 12 01/08/2019 19
5 01/06/2019 8 01/07/2019 9
6 01/07/2019 9 01/08/2019 19
7 01/08/2019 19 NaN NaN
8 01/09/2019 12 NaN NaN
9 01/10/2019 3 NaN NaN
我已经测试了一些解决方案,这些解决方案可以满足我的要求,但性能非常差,并且由于真实的数据帧包含超过一百万行,因此它们是不切实际的。
这些是我的测试解决方案:
使用Pandasql
result = sqldf("SELECT l.date, l.price, min(r.date) as next_date " +
"from df as l left join df as r on (r.date > l.date " +
"and r.price > l.price) group by l.date, l.price order by l.date")
result = pd.merge(result ,df, left_on='next_date', right_on='date',
suffixes=('', '_next'), how='left')
print(result)
使用 Pandas 到 SQLite
df.to_sql('df', conn, index=False)
qry = "SELECT l.date, l.price, min(r.date) as next_date from df as " +
"l left join df as r on (r.date > l.date and r.price > l.price) " +
"group by l.date, l.price order by l.date"
result = pd.read_sql_query(qry, conn)
result = pd.merge(result ,df, left_on='next_date', right_on='date',
suffixes=('', '_next'), how='left')
print(result)
使用
apply
def find_next_price(row):
mask = (df['price'] > row['price']) & (df['date'] > row['date'])
if len(df[mask]):
return df[mask]['date'].iloc[0], df[mask]['price'].iloc[0]
else:
return np.nan, np.nan
df[['next_date', 'next_price']] = list(df.apply(find_next_price, axis=1))
print(df)
其中一些解决方案在只有 50,000 行时就开始失败, 虽然我需要对一百万行执行此任务。
注意: 有一个非常相似的问题here但答案的表现仍然不足。
由于您需要对大量行(1M+)执行此任务,因此使用
numpy
的传统方法可能不可行,尤其是当您的内存量有限时。在这里,我提出了一种使用基本算法计算的函数方法,您可以使用 numba's
即时编译器来编译此函数,以实现 C
的速度:
import numba
@numba.njit
def argmax(price: np.ndarray):
for i in range(len(price)):
idx = -1
for j in range(i + 1, len(price)):
if price[i] < price[j]:
idx = j
break
yield idx
idx = -1
i = np.array(list(argmax(df['price'].values)))
m = i != -1 # index is -1 if there's no next greater price
df.loc[m, 'next_date'] = df['date'].values[i[m]]
df.loc[m, 'next_price'] = df['price'].values[i[m]]
结果
date price next_date next_price
0 01/01/2019 10 01/05/2019 12.0
1 01/02/2019 2 01/03/2019 5.0
2 01/03/2019 5 01/05/2019 12.0
3 01/04/2019 4 01/05/2019 12.0
4 01/05/2019 12 01/08/2019 19.0
5 01/06/2019 8 01/07/2019 9.0
6 01/07/2019 9 01/08/2019 19.0
7 01/08/2019 19 NaN NaN
8 01/09/2019 12 NaN NaN
9 01/10/2019 3 NaN NaN
PS:解决方案已在 1M+ 行上进行了测试。