Numpy索引值拾取规则

问题描述 投票:0回答:1

让我们假设一个这样的程序:

x = np.arange(0, 12).reshape(2, 3, 2)
y = x[0:2, [0, 1], 0:2]
print(y.flatten())

此代码打印

[0 1 2 3 6 7 8 9]
。 当然,这些每个值都是从原始ndarray中提取出来的
x
。 我检查了原始
x
中每个值的索引,下面是它:

0: x[0, 0, 0]                                                                                                                                                                                                                                                                             
1: x[0, 0, 1]                                                                                                                                                                                                                                                                             
2: x[0, 1, 0]                                                                                                                                                                                                                                                                             
3: x[0, 1, 1]                                                                                                                                                                                                                                                                             
6: x[1, 0, 0]                                                                                                                                                                                                                                                                             
7: x[1, 0, 1]                                                                                                                                                                                                                                                                             
8: x[1, 1, 0]                                                                                                                                                                                                                                                                             
9: x[1, 1, 1]

查看前 4 行,

x[0, 0, 0]
x[0, 0, 1]
之前。之后,
x[0, 1, 0]
x[0, 1, 1]
即将到来。 因此,从 x 中获取值的顺序“规则”如下所示:首先,对第一个参数
0:2
进行循环。然后,在循环中,对第二个参数
[0, 1]
进行内部循环。最后,第三个参数
0:2
的最后一个循环发生在内部。伪代码如下所示:

indices = []
for i in 0:2 { # first arg
    for j in [0, 1] { # second arg
        for k in 0:2 { # third arg
            indices.append([i, j, k])
        }
    }
}

这也是我对numpy取值规则的最初理解。

但是下面的代码看起来不符合规则。

x = np.arange(0, 72).reshape(2, 3, 2, 3, 2)
x[0:2, 2, 1:2, [0, 1], 0:2]
print(y.flatten())

此代码打印

[30 31 66 67 32 33 68 69]
,索引为:

30: x[0, 2, 1, 0, 0]                                                                                                                                                                                                                                                                      
31: x[0, 2, 1, 0, 1]                                                                                                                                                                                                                                                                      
66: x[1, 2, 1, 0, 0]                                                                                                                                                                                                                                                                      
67: x[1, 2, 1, 0, 1]                                                                                                                                                                                                                                                                      
32: x[0, 2, 1, 1, 0]                                                                                                                                                                                                                                                                      
33: x[0, 2, 1, 1, 1]                                                                                                                                                                                                                                                                      
68: x[1, 2, 1, 1, 0]                                                                                                                                                                                                                                                                      
69: x[1, 2, 1, 1, 1]

如果上述规则正确,伪代码将如下所示:

for i in 0:2 {
    for j in 2 {
        for k in 1:2 {
            for l in [0:1] {
                for m in 0:2 {
                    indices.append([i, j, k, l, m])
                }
            }
        }
    }
}

然而,实际结果表明,第四个参数

[0, 1]
的循环发生在第一个参数
0:2
的循环之后。我想知道为什么会发生这种情况。

我认为我遗漏或误解了 numpy 高级索引规则的一些内容,但为什么会有这样的差异?

python numpy
1个回答
0
投票

第一种情况:

In [149]: x = np.arange(0, 12).reshape(2, 3, 2)
     ...: y = x[0:2, [0, 1], 0:2]

strides
显示了如何访问元素(以字节为单位)、最左边最大的元素以及在测试时迭代最外层元素:

In [150]: x.shape, x.strides
Out[150]: ((2, 3, 2), (24, 8, 4))

y
因为它有一个列表索引 [0,1],所以有不同的步幅顺序,尽管这不会扰乱扁平化。

In [151]: y.shape, y.strides
Out[151]: ((2, 2, 2), (8, 16, 4))

In [152]: y
Out[152]: 
array([[[0, 1],
        [2, 3]],

       [[6, 7],
        [8, 9]]])

In [153]: y.ravel()
Out[153]: array([0, 1, 2, 3, 6, 7, 8, 9])

在幕后

numpy
首先使用 [0,1] 进行索引,然后转换为正确的顺序。在大多数情况下,此步骤对用户来说是透明的。

In [154]: y.base
Out[154]: 
array([[[0, 1],
        [6, 7]],

       [[2, 3],
        [8, 9]]])

你的第二个例子,中间有一个切片。如文档所述,切片尺寸放在最后。

In [156]: x = np.arange(0, 72).reshape(2, 3, 2, 3, 2)
     ...: y = x[0:2, 2, 1:2, [0, 1], 0:2]

In [157]: y.shape, y.strides
Out[157]: ((2, 2, 1, 2), (16, 8, 8, 4))

我认为这里发生的事情与第一种情况相同,但它可以(明确地)转置这个

y

如果我用列表替换

1:2
切片:

In [158]: x = np.arange(0, 72).reshape(2, 3, 2, 3, 2)
     ...: y = x[0:2, 2, [1], [0, 1], 0:2]

In [159]: y.shape, y.strides
Out[159]: ((2, 2, 2), (8, 16, 4))

压平2个案例:

In [160]: x[0:2, 2, 1:2, [0, 1], 0:2].ravel()
Out[160]: array([30, 31, 66, 67, 32, 33, 68, 69])

In [161]: x[0:2, 2, [1], [0, 1], 0:2].ravel()
Out[161]: array([30, 31, 32, 33, 66, 67, 68, 69])

“工作”案例的基础与混合基本/高级案例相同,除了额外的 1 维尺寸:

In [162]: x[0:2, 2, [1], [0, 1], 0:2].base
Out[162]: 
array([[[30, 31],
        [66, 67]],

       [[32, 33],
        [68, 69]]])

In [163]: x[0:2, 2, 1:2, [0, 1], 0:2].base  # pure copy

In [164]: x[0:2, 2, 1:2, [0, 1], 0:2]
Out[164]: 
array([[[[30, 31]],

        [[66, 67]]],


       [[[32, 33]],

        [[68, 69]]]])

[158]

y
是 [162] 基数转置,就像第一个示例所做的那样。

当生成的形状全部不同时,混合基本/高级索引的混乱尺寸顺序更加明显。

我可以找到适当的文档部分,并且一些 SO 可以针对此问题进行重复删除。

© www.soinside.com 2019 - 2024. All rights reserved.