我是 numba 新手,并试图了解它是如何工作的。
我想转换大气模型中 4D 风矢量的投影。矢量维度为(时间、高度、纬度、经度)。由于重新投影向量的 cartopy 函数仅接受 2D(水平)向量场,因此我必须循环遍历时间和高度维度。由于这是一个非常大的数组(大小为 [48, 90, 1200, 1200]),我试图使用 numba 使嵌套循环更快,但没有成功。即使使用 numba 后,代码似乎也没有加速,并且达到了 2 小时的极限时间!是因为numba无法加速cartopy功能吗?
请参阅下面我的代码的精简版本。我做错了什么?
import cartopy.crs as ccrs
import numpy, xarray
from numba import jit, prange
def coord_ref_systems(cube):
global source_crs, target_crs
source_crs = ccrs.RotatedPole(pole_longitude=140,
pole_latitude=36,
globe=ccrs.Globe(semimajor_axis=6370000, semiminor_axis=6370000))
target_crs = ccrs.PlateCarree()
def vector_transform(u, v, lon, lat):
xx, yy = numpy.meshgrid(lon, lat)
u_t, v_t = target_crs.transform_vectors(source_crs, xx, yy, u, v)
return u_t, v_t
@jit(parallel=True)
def transformer(u, v, lon, lat):
u_new, v_new = numpy.empty(u.shape), numpy.empty(u.shape)
for ti in prange(len(u.shape[0])):
for zi in prange(len(u.shape[1])):
u_t, v_t = vector_transform(u[ti, zi, :, :].squeeze(), v[ti, zi, :, :].squeeze(), lon, lat)
u_new[ti, zi, :, :] = u_t
v_new[ti, zi, :, :] = v_t
return u_new, v_new
def main():
ds = xarray.open_mfdataset('wind_vectors.nc', chunks={'time':6, 'height':-1, 'latitude': 200, 'longitude': 200})
u = ds['U']
v = ds['V']
_u_, _v_ = transformer(u.values, v.values, u.longitude.values, u.latitude.values)
if __name__ == '__main__':
main()