Advanced (non-orthogonal) indexing in xarray


I want to select several points from a DataArray in the same way I would in numpy (arr_np[:, x_idxs, y_idxs]), but advanced indexing in xarray seems to work differently from advanced indexing in numpy.

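I think this is because xarray performs orthogonal indexing by default (as described in the "Vectorized Indexing" section of the documentation).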
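To make the difference concrete in plain numpy: orthogonal ("outer") indexing with lists behaves like numpy indexing through np.ix_, which takes the full cross product of the listed x and y indices, whereas passing the two lists directly to numpy pairs them element-wise. A minimal sketch, assuming the same 4 x 5 x 6 array that is built below:

```python
import numpy as np

arr_np = np.arange(4 * 5 * 6).reshape(4, 5, 6)
x_idxs, y_idxs = [0, 3, 4], [2, 3, 0]

# Orthogonal ("outer") indexing: all bands, crossed with every listed x and every listed y.
# np.ix_ turns the three 1-D index lists into open-mesh arrays that broadcast as a cross product.
orthogonal = arr_np[np.ix_(np.arange(4), x_idxs, y_idxs)]

# Vectorized (pointwise) indexing: x_idxs[i] is paired with y_idxs[i].
vectorized = arr_np[:, x_idxs, y_idxs]

print(orthogonal.shape)  # (4, 3, 3)
print(vectorized.shape)  # (4, 3)
```

The first shape matches the xarray output shown further down; the second matches the numpy result the question is after.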

Suppose I have the following array:

from string import ascii_lowercase, ascii_uppercase

import xarray as xr
import numpy as np


sizes = {"band": 4, "x": 5, "y": 6}
shape = tuple(sizes.values())
dims = tuple(sizes.keys())

arr_np = np.arange(np.prod(shape)).reshape(shape)
arr_xr = xr.DataArray(
    data=arr_np,
    dims=dims,
    coords={
        "band": np.arange(100, 100 + sizes["band"]),
        "x": list(ascii_lowercase[: sizes["x"]]),  # abcde...
        "y": list(ascii_uppercase[: sizes["y"]]),  # ABCDE...
    },
)
>>> arr_xr
<xarray.DataArray (band: 4, x: 5, y: 6)>
array([[[  0,   1,   2,   3,   4,   5],
        [  6,   7,   8,   9,  10,  11],
        [ 12,  13,  14,  15,  16,  17],
        [ 18,  19,  20,  21,  22,  23],
        [ 24,  25,  26,  27,  28,  29]],

       ...

       [[ 90,  91,  92,  93,  94,  95],
        [ 96,  97,  98,  99, 100, 101],
        [102, 103, 104, 105, 106, 107],
        [108, 109, 110, 111, 112, 113],
        [114, 115, 116, 117, 118, 119]]])
Coordinates:
  * band     (band) int64 100 101 102 103
  * x        (x) <U1 'a' 'b' 'c' 'd' 'e'
  * y        (y) <U1 'A' 'B' 'C' 'D' 'E' 'F'

I want to sample the following points (formatted here in every way I could think of that might be usable):

points = [("a", "C"), ("d", "D"), ("e", "A")]
points_idxs = [(0, 2), (3, 3), (4, 0)]

xs, ys = map(list, zip(*points))  # ['a', 'd', 'e'] and ['C', 'D', 'A']
x_idxs, y_idxs = map(list, zip(*points_idxs))  # [0, 3, 4] and [2, 3, 0]

In numpy I can do:

expected = arr_np[:, x_idxs, y_idxs]
assert expected.shape == (sizes["band"], len(points))
>>> expected
array([[  2,  21,  24],
       [ 32,  51,  54],
       [ 62,  81,  84],
       [ 92, 111, 114]])
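What numpy does here can be spelled out as an explicit loop over the points: one column of band values per (x, y) pair, stacked along a new last axis. This small sketch rebuilds the array from the question and checks that the loop and the fancy-indexing expression agree:

```python
import numpy as np

arr_np = np.arange(4 * 5 * 6).reshape(4, 5, 6)
points_idxs = [(0, 2), (3, 3), (4, 0)]
x_idxs, y_idxs = map(list, zip(*points_idxs))

# numpy pairs the two index lists element-wise, so the fancy-indexing result
# equals taking arr_np[:, x, y] for each point and stacking those columns.
loop_result = np.stack([arr_np[:, x, y] for x, y in points_idxs], axis=1)
fancy_result = arr_np[:, x_idxs, y_idxs]

assert np.array_equal(loop_result, fancy_result)
print(fancy_result.shape)  # (4, 3)
```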

But doing the same on the DataArray does not give the same output:

>>> # Same output for the 3 ways
>>> arr_xr.loc[:, xs, ys]
>>> arr_xr.sel(x=xs, y=ys)
>>> arr_xr[:, x_idxs, y_idxs]
<xarray.DataArray (band: 4, x: 3, y: 3)>
array([[[  2,   3,   0],
        [ 20,  21,  18],
        [ 26,  27,  24]],

       [[ 32,  33,  30],
        [ 50,  51,  48],
        [ 56,  57,  54]],

       [[ 62,  63,  60],
        [ 80,  81,  78],
        [ 86,  87,  84]],

       [[ 92,  93,  90],
        [110, 111, 108],
        [116, 117, 114]]])
Coordinates:
  * band     (band) int64 100 101 102 103
  * x        (x) <U1 'a' 'd' 'e'
  * y        (y) <U1 'C' 'D' 'A'
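The two results are related: the orthogonal result has shape (band, 3, 3) and holds every (x, y) combination, and the wanted points are exactly the entries where the i-th x is paired with the i-th y, i.e. its diagonal over the last two axes. A numpy-only sketch to check this; it builds the full cross product first, so treat it as a sanity check rather than an efficient way to extract points:

```python
import numpy as np

arr_np = np.arange(4 * 5 * 6).reshape(4, 5, 6)
x_idxs, y_idxs = [0, 3, 4], [2, 3, 0]

# Same values as the xarray output above: shape (band, len(x_idxs), len(y_idxs))
orthogonal = arr_np[:, x_idxs][:, :, y_idxs]

# Keep only the (i, i) entries, i.e. x_idxs[i] paired with y_idxs[i];
# np.diagonal puts the diagonal on the last axis, giving shape (band, n_points)
diag = np.diagonal(orthogonal, axis1=1, axis2=2)

assert np.array_equal(diag, arr_np[:, x_idxs, y_idxs])
print(diag.shape)  # (4, 3)
```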

How can I get the expected output?

2 Answers

Answer 1

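I managed to do it with .stack(), but only through the .sel() method, i.e. with label-based indexing; I did not manage to do it with position-based (integer) indexing.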

>>> points
[('a', 'C'), ('d', 'D'), ('e', 'A')]
>>> arr_xr.stack(pos=("x", "y")).loc[:, points]  # or
>>> arr_xr.stack(pos=("x", "y")).sel(pos=points)
<xarray.DataArray (band: 4, pos: 3)>
array([[  2,  21,  24],
       [ 32,  51,  54],
       [ 62,  81,  84],
       [ 92, 111, 114]])
Coordinates:
  * band     (band) int64 100 101 102 103
  * pos      (pos) object MultiIndex
  * x        (pos) <U1 'a' 'd' 'e'
  * y        (pos) <U1 'C' 'D' 'A'

This does not work with position-based indexing, so the question still stands:

>>> arr_xr.stack(pos=("x", "y")).isel(pos=points_idxs)  # or
>>> arr_xr.stack(pos=("x", "y"))[:, points_idxs]
IndexError: Unlabeled multi-dimensional array cannot be used for indexing: pos
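One possible position-based workaround, sketched under the assumption that stacking x and y orders the new pos dimension x-major (the same C order numpy's reshape uses): convert each (x, y) index pair into a single flat position with np.ravel_multi_index, which is then an ordinary 1-D integer indexer. The sketch below does this on the plain numpy array; on the stacked DataArray the same flat positions would presumably be passed as .isel(pos=flat_idxs), but that part is not tested here.

```python
import numpy as np

sizes = {"band": 4, "x": 5, "y": 6}
arr_np = np.arange(4 * 5 * 6).reshape(4, 5, 6)
points_idxs = [(0, 2), (3, 3), (4, 0)]
x_idxs, y_idxs = map(list, zip(*points_idxs))

# Flat position of each (x, y) pair in a C-ordered (x, y) grid: x * n_y + y
flat_idxs = np.ravel_multi_index((x_idxs, y_idxs), (sizes["x"], sizes["y"]))

# Merge x and y into one axis (analogous to stack(pos=("x", "y"))), then index it with 1-D ints
stacked = arr_np.reshape(sizes["band"], sizes["x"] * sizes["y"])
result = stacked[:, flat_idxs]

assert np.array_equal(result, arr_np[:, x_idxs, y_idxs])
```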

Answer 2

I finally got it working thanks to @hpaulj's hint: use DataArray objects as indexers, which introduces a new dimension "pos" in the result array.

Using the variables defined in the question:

Position-based indexing

# x_idxs, y_idxs = [0, 3, 4], [2, 3, 0]

# Using DataArray object as indexer
x_idxs_da = xr.DataArray(x_idxs, dims=["pos"])
y_idxs_da = xr.DataArray(y_idxs, dims=["pos"])

result = arr_xr.isel(x=x_idxs_da, y=y_idxs_da)
assert np.all(result.data == expected)

result = arr_xr[:, x_idxs_da, y_idxs_da]
assert np.all(result.data == expected)

# Or annotated tuple to be implicitly converted to DataArray indexer
result = arr_xr.isel(x=("pos", x_idxs), y=("pos", y_idxs))
assert np.all(result.data == expected)

result = arr_xr[:, ("pos", x_idxs), ("pos", y_idxs)]
assert np.all(result.data == expected)
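For reference, a self-contained version of the position-based approach (rebuilding the question's array, assuming a reasonably recent xarray) that also shows the shape of the result: both indexers share the new dimension pos, so x and y are replaced by pos, and the selected x and y labels are kept as coordinates along pos.

```python
import numpy as np
import xarray as xr

arr_xr = xr.DataArray(
    np.arange(4 * 5 * 6).reshape(4, 5, 6),
    dims=("band", "x", "y"),
    coords={"band": np.arange(100, 104), "x": list("abcde"), "y": list("ABCDEF")},
)

# Indexers that share the dimension "pos" are applied point-wise (vectorized indexing)
x_idxs_da = xr.DataArray([0, 3, 4], dims="pos")
y_idxs_da = xr.DataArray([2, 3, 0], dims="pos")

result = arr_xr.isel(x=x_idxs_da, y=y_idxs_da)

print(result.dims)                  # ('band', 'pos')
print(result["x"].values.tolist())  # ['a', 'd', 'e']
print(result["y"].values.tolist())  # ['C', 'D', 'A']
```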

Coordinate-based (label) indexing

# xs, ys = ['a', 'd', 'e'], ['C', 'D', 'A']

# Using DataArray object as indexer
xs_da = xr.DataArray(xs, dims=["pos"])
ys_da = xr.DataArray(ys, dims=["pos"])

result = arr_xr.sel(x=xs_da, y=ys_da)
assert np.all(result.data == expected)

result = arr_xr.loc[:, xs_da, ys_da]
assert np.all(result.data == expected)

# Can't use implicit conversion
# result = arr_xr.sel(x=("pos", xs), y=("pos", ys))  # InvalidIndexError
# result = arr_xr.loc[:, ("pos", xs), ("pos", ys)]  # InvalidIndexError
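A related sketch for the label-based case: because the indexers are DataArrays, they can carry their own coordinate on pos (here the hypothetical point names p0, p1, p2, which are not part of the question), and xarray attaches that coordinate to the result, so individual sampled points can be selected by name afterwards.

```python
import numpy as np
import xarray as xr

arr_xr = xr.DataArray(
    np.arange(4 * 5 * 6).reshape(4, 5, 6),
    dims=("band", "x", "y"),
    coords={"band": np.arange(100, 104), "x": list("abcde"), "y": list("ABCDEF")},
)

points = [("a", "C"), ("d", "D"), ("e", "A")]
point_names = ["p0", "p1", "p2"]  # hypothetical labels for the sampled points

# One label indexer per dimension; both share the "pos" dimension and its coordinate
xs_da = xr.DataArray([x for x, _ in points], dims="pos", coords={"pos": point_names})
ys_da = xr.DataArray([y for _, y in points], dims="pos", coords={"pos": point_names})

result = arr_xr.sel(x=xs_da, y=ys_da)
print(result.sel(pos="p1").values.tolist())  # [21, 51, 81, 111]
```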