覆盖 np.ndarray __getitem__

问题描述 投票:0回答:1

目标

我正在处理二维时间序列数据,并且从不使用负索引。所以我想对 np.ndarray 进行子类化,轴 0 中的负数和越界索引将返回具有合适形状的 nan 增广矩阵。例如,

>>> test
>>> array([[0, 1, 2],
           [3, 4, 5]])

Target:
>>> test[-1:1]   # np.array([0, 1, 2])
>>> array(
[[nan, nan, nan],
[0, 1, 2]])

>>> test[1:4]   # np.array([3, 4, 5])
>>> array([
[3, 4, 5],
[nan, nan, nan],
[nan, nan, nan]])

理想的解决方案是覆盖 __getitem__,测试第一个切片是否包含负数或越界索引,更改它,传入 super().__getitem__,并将输出与适当的 nan 数组连接起来。

这是一个示例类,用于测试当我们索引 ndarray 时传入的内容。然而,输出很奇怪

class NoneNeg(np.ndarray):
    def __getitem__(self, index):
        print(key)
        return super(NoneNeg, self).__getitem__(index)

>>> test.view(NoneNeg)
>>>(-2, -3)
(-2, -2)
(-2, -1)
(-1, -3)
(-1, -2)
(-1, -1)
NoneNeg([[0, 1, 2],
         [3, 4, 5]])

问题

  1. 索引ndarray时会传入什么?例如,测试[:, 2] 和测试[2]。猜测可能是 (slice(None), 2) 和 (2,)。
  2. 为什么在查看的对象之前有一些负索引元组输出?看来我无法改变负索引,因为它无处不在。
python arrays numpy subclassing
1个回答
0
投票

当您对 ndarray 进行索引时,getitem方法会根据索引操作接收各种形式的索引元组。例如:

  • test[:, 2] 将翻译为 (slice(None, None, None), 2)

  • test[2] 确实是 (2,)

您在查看的对象之前看到负索引元组的原因是当您访问 ndarray 本身而不指定索引时,getitem的默认行为。当您创建 NoneNeg 类并查看 ndarray 时,首先打印 ndarray 的表示形式,显示您观察到的负索引元组。

要处理 NoneNeg 类中的负数或越界索引,您可以分析传递给 getitem 的索引,根据需要修改它们,然后将修改后的索引传递给超类的 getitem。您可以根据负索引或越界索引将输出与合适的 NaN 增强数组连接起来。

这是对 NoneNeg 类的示例修改,它处理轴 0 中的负索引并返回 NaN 增强数组:

import numpy as np

class NoneNeg(np.ndarray):
    def __getitem__(self, index):
        if isinstance(index, tuple):
            modified_index = list(index)
            if isinstance(index[0], int) and index[0] < 0:
                modified_index[0] = slice(None, index[0] + 1)
            result = super(NoneNeg, self).__getitem__(tuple(modified_index))
            
            # Create NaN-augmented array for negative or out-of-bounds indexing
            if isinstance(index[0], int) and index[0] < 0:
                nan_shape = (abs(index[0]),) + result.shape[1:]
                nan_array = np.full(nan_shape, np.nan)
                result = np.concatenate((nan_array, result), axis=0)
                
            return result
        else:
            return super(NoneNeg, self).__getitem__(index)

# Test
test = np.array([[0, 1, 2], [3, 4, 5]])
none_neg_test = test.view(NoneNeg)

print(none_neg_test[-1:1])  # Output: 
                            # array([[nan, nan, nan],
                            #        [0.,  1.,  2.]])

print(none_neg_test[1:4])   # Output:
                            # array([[3., 4., 5.],
                            #        [nan, nan, nan],
                            #        [nan, nan, nan]])

此修改检查第一个索引是否为负整数,适当调整索引,从超类中获取相应的切片,然后在必要时将其与 NaN 填充的数组连接起来。

© www.soinside.com 2019 - 2024. All rights reserved.