Converting the output of cumsum() to a binary array in xarray


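I have a 3D xarray DataArray on which I compute a cumulative sum over specific time periods. I want to detect which periods meet a certain condition (and set them to 1) and which do not (and set them to 0). The code below illustrates the problem.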

import pandas as pd
import xarray as xr
import numpy as np

# Create demo x-array
data = np.random.rand(20, 5, 5)
times = pd.date_range('2000-01-01', periods=20)
lats = np.arange(10, 0, -2)
lons = np.arange(0, 10, 2)
data = xr.DataArray(data, coords=[times, lats, lons], dims=['time', 'lat', 'lon'])
data.values[6:12] = 0 # Ensure some values are set to zero so that the cumsum can reset between valid time steps
data.values[18:] = 0

# This creates an xarray whereby the cumsum is calculated but resets each time a zero value is found
cumulative = data.cumsum(dim='time')-data.cumsum(dim='time').where(data.values == 0).ffill(dim='time').fillna(0)

print(cumulative[:,0,0])

>>> <xarray.DataArray (time: 20)>
array([0.13395 , 0.961934, 1.025337, 1.252985, 1.358501, 1.425393, 0.      ,
       0.      , 0.      , 0.      , 0.      , 0.      , 0.366988, 0.896463,
       1.728956, 2.000537, 2.316263, 2.922798, 0.      , 0.      ])
Coordinates:
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
    lat      int64 10
    lon      int64 0

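The print statement shows that the cumulative sum resets every time a 0 is encountered along the time dimension. I need a way to identify which of the two periods exceeds a value of 2, and to convert the result to a binary array marking where the condition is met (1) and where it is not (0).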
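To make the reset behaviour concrete, here is a minimal plain-Python sketch for a single grid cell. The helper name `cumsum_with_reset` is hypothetical and not part of xarray; it is only meant to mirror, step by step, what the one-line expression above computes along the time dimension.

```python
def cumsum_with_reset(values):
    """Running total that restarts from zero whenever a zero is encountered.

    Hypothetical reference helper for one time series (one lat/lon cell).
    """
    total = 0.0
    result = []
    for v in values:
        # A zero ends the current period and resets the running total
        total = 0.0 if v == 0 else total + v
        result.append(total)
    return result


series = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0]
print(cumsum_with_reset(series))  # [1.0, 3.0, 0.0, 0.0, 3.0, 7.0, 0.0]
```

Each run of nonzero values between zeros is what I call a "period" in the rest of the question.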

So my expected output would be (for this specific example):

<xarray.DataArray (time: 20)>
array([0.      , 0.      , 0.      , 0.      , 0.      , 0.     , 0.     ,
       0.      , 0.      , 0.      , 0.      , 0.      , 1.     , 1.     ,
       1.      , 1.      , 1.      , 1.      , 0.      , 0.     ])
python numpy python-xarray cumsum
1 Answer

I solved this with some masking and backfilling.

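# Create an output array of zeros with the same shape and coordinates
out = xr.full_like(cumulative, fill_value=0.0)

# Mark the points where the cumulative sum has exceeded the threshold of 2
out.values[cumulative.values > 2] = 1
# Set the remaining nonzero points (not yet above 2) to NaN so they can be backfilled
out.values[(cumulative.values > 0) & (cumulative.values <= 2)] = np.nan

# Backfill along time: NaNs in a period that later exceeds 2 take the value 1,
# and NaNs in a period that ends without exceeding 2 take the 0 that follows it.
# A period still below 2 at the end of the record has nothing after it to fill
# from, so it is set to 0.
out_ds = out.bfill(dim='time').fillna(0)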

print('Cumulative array:')
print(cumulative.values[:, 0, 0])
print(' ')
print('Binary array:')
print(out_ds.values[:, 0, 0])
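As a cross-check on the vectorised result, the same rule can be written as an explicit loop over one grid cell. This is only a sketch: `flag_runs` is a hypothetical helper, the threshold of 2 is taken from the question, and the input values are the ones printed for `cumulative[:, 0, 0]` in the question.

```python
def flag_runs(cumulative, threshold):
    """Return 1.0 for every step of a nonzero run whose maximum exceeds
    `threshold`, and 0.0 everywhere else.

    Hypothetical reference helper for one time series (one lat/lon cell).
    """
    cumulative = list(cumulative)
    flags = [0.0] * len(cumulative)
    start = None
    # The appended 0.0 is a sentinel that closes a run reaching the end of the series
    for i, value in enumerate(cumulative + [0.0]):
        if value > 0 and start is None:
            start = i  # a new run begins here
        elif value == 0 and start is not None:
            # The run covers indices start .. i-1; flag all of it if it passed the threshold
            if max(cumulative[start:i]) > threshold:
                for j in range(start, i):
                    flags[j] = 1.0
            start = None
    return flags


# Values printed for cumulative[:, 0, 0] in the question
cell = [0.13395, 0.961934, 1.025337, 1.252985, 1.358501, 1.425393,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.366988, 0.896463, 1.728956, 2.000537, 2.316263, 2.922798,
        0.0, 0.0]
print(flag_runs(cell, threshold=2))
```

For that cell, only the second run (time steps 12 to 17) is flagged, which matches the expected output in the question. The loop is much slower than the masked backfill on a full 3D array, so it is best kept for spot-checking individual cells.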