Pandas ExtensionArray 的简单示例

问题描述 投票:0回答:2

在我看来,Pandas 的

ExtensionArray
正是那种“有一个简单示例作为起点会非常有帮助”的情况。然而,我还没有在任何地方找到足够简单的例子。

创建一个
ExtensionArray

要创建

ExtensionArray
,您需要

Pandas 文档中还有一个 部分,其中有简要概述。

实现示例

有很多实现示例:

问题

尽管研究了以上所有内容,我仍然发现扩展数组很难理解。所有示例都有很多细节和自定义功能,因此很难弄清楚实际需要什么。我怀疑很多人都面临过类似的问题。

因此,我希望有人能提供一个

ExtensionArray
的最小可运行示例。该类应该通过 Pandas 提供的所有测试,以检查
ExtensionArray
的行为是否符合预期。我在下面提供了这些测试的示例实现。

举一个具体的例子,假设我想扩展

ExtensionArray
以获得一个能够保存 NA 值的整数数组。这本质上是
IntegerArray
,但剥离了
ExtensionArray
基础知识之外的任何实际功能。


测试解决方案

我使用了以下装置和测试来测试解决方案的有效性。这些基于 Pandas 文档

中的指示
import operator

import numpy as np
from pandas import Series
import pytest

from pandas.tests.extension.base.casting import BaseCastingTests  # noqa
from pandas.tests.extension.base.constructors import BaseConstructorsTests  # noqa
from pandas.tests.extension.base.dtype import BaseDtypeTests  # noqa
from pandas.tests.extension.base.getitem import BaseGetitemTests  # noqa
from pandas.tests.extension.base.groupby import BaseGroupbyTests  # noqa
from pandas.tests.extension.base.interface import BaseInterfaceTests  # noqa
from pandas.tests.extension.base.io import BaseParsingTests  # noqa
from pandas.tests.extension.base.methods import BaseMethodsTests  # noqa
from pandas.tests.extension.base.missing import BaseMissingTests  # noqa
from pandas.tests.extension.base.ops import (  # noqa
    BaseArithmeticOpsTests,
    BaseComparisonOpsTests,
    BaseOpsUtil,
    BaseUnaryOpsTests,
)
from pandas.tests.extension.base.printing import BasePrintingTests  # noqa
from pandas.tests.extension.base.reduce import (  # noqa
    BaseBooleanReduceTests,
    BaseNoReduceTests,
    BaseNumericReduceTests,
)
from pandas.tests.extension.base.reshaping import BaseReshapingTests  # noqa
from pandas.tests.extension.base.setitem import BaseSetitemTests  # noqa

from .extension import NullableIntArray



@pytest.fixture
def dtype():
    """A fixture providing the ExtensionDtype to validate."""
    # NOTE(review): this returns the dtype's name as a plain string, not an
    # ExtensionDtype instance as the docstring promises — confirm whether the
    # dtype class from `.extension` should be instantiated here instead.
    return 'NullableInt'


@pytest.fixture
def data():
    """
    Provide the main length-100 array used throughout the suite.

    The values are the integers 0..99, so the first two entries are both
    present (non-missing) and differ from each other.
    """
    values = np.arange(100)
    return NullableIntArray(values)


@pytest.fixture
def data_for_twos():
    """Length-100 array in which all the elements are two."""
    # The array must have the same length as `data` (100); the previous
    # `[2] * 2` only produced two elements.
    return NullableIntArray(np.array([2] * 100))


@pytest.fixture
def data_missing():
    """Two-element array laid out as [NA, valid]."""
    values = np.array([np.nan, 2])
    return NullableIntArray(values)


@pytest.fixture(params=["data", "data_missing"])
def all_data(request, data, data_missing):
    """Run the requesting test once with `data` and once with `data_missing`."""
    choices = {"data": data, "data_missing": data_missing}
    return choices.get(request.param)


@pytest.fixture
def data_repeated(data):
    """
    Build a factory that hands out the `data` fixture several times.

    Parameters
    ----------
    data : fixture implementing `data`

    Returns
    -------
    Callable[[int], Generator]:
        Called with `count`, it returns a generator that yields the same
        `data` object `count` times.
    """

    def gen(count):
        yield from (data for _ in range(count))

    return gen


@pytest.fixture
def data_for_sorting():
    """
    Three-element array with a known sort order.

    Laid out as [B, C, A] where A < B < C.
    """
    ordered = [2, 3, 1]  # B, C, A
    return NullableIntArray(np.array(ordered))


@pytest.fixture
def data_missing_for_sorting():
    """
    Three-element array with a known sort order and one missing value.

    Laid out as [B, NA, A] where A < B.
    """
    ordered = [2, np.nan, 1]  # B, NA, A
    return NullableIntArray(np.array(ordered))


@pytest.fixture
def na_cmp():
    """
    Provide the binary predicate used to compare two scalar NA values.

    The returned callable takes two arguments and should be true when both
    are the NA scalar of this type. Identity comparison (``operator.is_``)
    is used here.
    """
    identity_check = operator.is_
    return identity_check


@pytest.fixture
def na_value():
    """The scalar missing value for this type (``np.nan`` here)."""
    missing = np.nan
    return missing


@pytest.fixture
def data_for_grouping():
    """
    Array used by the factorization, grouping and unique tests.

    Laid out as [B, B, NA, NA, A, A, B, C] where A < B < C.
    """
    layout = [2, 2, np.nan, np.nan, 1, 1, 2, 3]
    return NullableIntArray(np.array(layout))


@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Flag: should the data be wrapped in a Series?"""
    boxed = request.param
    return boxed


@pytest.fixture(
    params=[
        lambda group: 1,
        lambda group: [1] * len(group),
        lambda group: Series([1] * len(group)),
        lambda group: group,
    ],
    ids=["scalar", "list", "series", "object"],
)
def groupby_apply_op(request):
    """
    Callables to pass to groupby.apply(): one returning a scalar, a list,
    a Series, and the group itself.
    """
    func = request.param
    return func


@pytest.fixture(params=[True, False])
def as_frame(request):
    """
    Flag used to compare a Series against its Series.to_frame() counterpart.
    """
    flag = request.param
    return flag


@pytest.fixture(params=[True, False])
def as_series(request):
    """
    Flag used to compare a raw array against Series(array).
    """
    flag = request.param
    return flag


@pytest.fixture(params=[True, False])
def use_numpy(request):
    """
    Flag used to compare an ExtensionDtype array against a numpy array.
    """
    flag = request.param
    return flag


@pytest.fixture(params=["ffill", "bfill"])
def fillna_method(request):
    """
    Supply 'ffill' and then 'bfill' as the method for
    Series.fillna(method=<method>) tests.
    """
    method = request.param
    return method


@pytest.fixture(params=[True, False])
def as_array(request):
    """
    Flag used by the ExtensionDtype _from_sequence tests.
    """
    flag = request.param
    return flag


# Each class below subclasses one of the pandas base extension test classes
# imported above without overriding anything, so the inherited tests run
# against the fixtures defined in this module.
class TestCastingTests(BaseCastingTests):
    pass


class TestConstructorsTests(BaseConstructorsTests):
    pass



class TestDtypeTests(BaseDtypeTests):
    pass


class TestGetitemTests(BaseGetitemTests):
    pass


class TestGroupbyTests(BaseGroupbyTests):
    pass


class TestInterfaceTests(BaseInterfaceTests):
    pass


class TestParsingTests(BaseParsingTests):
    pass


class TestMethodsTests(BaseMethodsTests):
    pass


class TestMissingTests(BaseMissingTests):
    pass


class TestArithmeticOpsTests(BaseArithmeticOpsTests):
    pass


class TestComparisonOpsTests(BaseComparisonOpsTests):
    pass


class TestOpsUtil(BaseOpsUtil):
    pass


class TestUnaryOpsTests(BaseUnaryOpsTests):
    pass


class TestPrintingTests(BasePrintingTests):
    pass


class TestBooleanReduceTests(BaseBooleanReduceTests):
    pass


# NOTE(review): the no-reduce tests and the boolean/numeric reduce tests are
# presumably mutually exclusive for a single array type — confirm which of
# them should be enabled here.
class TestNoReduceTests(BaseNoReduceTests):
    pass


class TestNumericReduceTests(BaseNumericReduceTests):
    pass


class TestReshapingTests(BaseReshapingTests):
    pass


class TestSetitemTests(BaseSetitemTests):
    pass
python pandas dataframe
2个回答
31
投票

更新2021-09-19

试图让

NullableIntArray
通过 测试套件 时遇到了太多问题,因此我创建了一个新示例 (
AngleDtype
+
AngleArray
),目前通过了 398 项测试(失败 2 项)。


0. 使用方法

(pandas 1.3.2,numpy 1.20.2,Python 3.9.2)

AngleArray
根据其
unit
(由
AngleDtype
表示)存储弧度或度数:

# Build an AngleArray in radians, then convert it to degrees with `asunit`.
# The commented lines show the repr printed after each statement.
thetas = [0, np.pi, 2 * np.pi]
a = AngleArray(thetas, unit='rad')
# <AngleArray>
# [0.0, 3.141592653589793, 6.283185307179586]
# Length: 3, dtype: angle[rad]
a = a.asunit('deg')
# <AngleArray>
# [0.0, 180.0, 360.0]
# Length: 3, dtype: angle[deg]

AngleArray
可以存储在
Series
或
DataFrame
中:

# Wrap the degree-valued array in a Series, then build a DataFrame whose
# second column is a radian-valued AngleArray (no `unit` given, so 'rad').
s = pd.Series(a)
# 0     0.0
# 1   180.0
# 2   360.0
# dtype: angle[deg]
df = pd.DataFrame({'a': s, 'b': AngleArray(thetas[::-1])})
#        a                  b
# 0    0.0  6.283185307179586
# 1  180.0  3.141592653589793
# 2  360.0                0.0
df['a']
# 0    0.0
# 1  180.0
# 2  360.0
# Name: a, dtype: angle[deg]
df['b']
# 0  6.283185307179586
# 1  3.141592653589793
# 2                0.0
# Name: b, dtype: angle[rad]

AngleArray
计算是
unit
感知的:

# Adding a degree column and a radian column: the right-hand operand is
# converted to the left-hand operand's unit, so the result is in degrees.
df['a + b'] = df['a'] + df['b']
#        a                  b  a + b
# 0    0.0  6.283185307179586  360.0
# 1  180.0  3.141592653589793  360.0
# 2  360.0                0.0  360.0
df['a + b']
# 0   360.0
# 1   360.0
# 2   360.0
# Name: a + b, dtype: angle[deg]

1.
AngleDtype

对于每一个

ExtensionDtype
,必须具体实现 3 个方法:

  • type
  • name
  • construct_array_type

对于参数化的

ExtensionDtype
(例如,
AngleDtype.unit
或
PeriodDtype.freq
):

  • _metadata(必需)
  • construct_from_string(推荐)

对于测试套件:

  • __hash__
  • __eq__
  • __setstate__
from __future__ import annotations

import operator
import re
from typing import Any, Sequence

import numpy as np
import pandas as pd


@pd.api.extensions.register_extension_dtype
class AngleDtype(pd.core.dtypes.dtypes.PandasExtensionDtype):
    """
    An ExtensionDtype for unit-aware angular data.

    Parameters
    ----------
    unit : {'rad', 'deg'}, optional
        Unit the angles are expressed in. ``None`` is treated as ``'rad'``.

    Raises
    ------
    ValueError
        If `unit` is neither ``'rad'`` nor ``'deg'``.
    """
    # Required for all parameterized dtypes
    _metadata = ('unit',)
    # Matches 'angle[<unit>]' and 'Angle[<unit>]'
    _match = re.compile(r'(A|a)ngle\[(?P<unit>.+)\]')

    def __init__(self, unit=None):
        if unit is None:
            unit = 'rad'

        if unit not in ['rad', 'deg']:
            msg = f"'{type(self).__name__}' only supports 'rad' and 'deg' units"
            raise ValueError(msg)

        self._unit = unit

    def __str__(self) -> str:
        return f'angle[{self.unit}]'

    # TestDtypeTests
    def __hash__(self) -> int:
        # Hash on the string form so that equal dtypes hash equally
        return hash(str(self))

    # TestDtypeTests
    def __eq__(self, other: Any) -> bool:
        # A string is equal when it is exactly this dtype's name
        if isinstance(other, str):
            return self.name == other
        else:
            return isinstance(other, type(self)) and self.unit == other.unit

    # Required for pickle compat (see GH26067)
    def __setstate__(self, state) -> None:
        self._unit = state['unit']

    # Required for all ExtensionDtype subclasses
    @classmethod
    def construct_array_type(cls):
        """
        Return the array type associated with this dtype.
        """
        return AngleArray

    # Recommended for parameterized dtypes
    @classmethod
    def construct_from_string(cls, string: str) -> "AngleDtype":
        """
        Construct an AngleDtype from a string.

        Parameters
        ----------
        string : str
            A name of the form ``'angle[rad]'`` or ``'angle[deg]'``.

        Raises
        ------
        TypeError
            If `string` is not a str, or is not exactly a valid dtype name.

        Example
        -------
        >>> AngleDtype.construct_from_string('angle[deg]')
        angle[deg]
        """
        if not isinstance(string, str):
            msg = f"'construct_from_string' expects a string, got {type(string)}"
            raise TypeError(msg)

        msg = f"Cannot construct a '{cls.__name__}' from '{string}'"
        # The whole string must match: a prefix match would wrongly accept
        # names with trailing characters such as 'angle[deg]xyz'.
        match = cls._match.fullmatch(string)

        if match is None:
            raise TypeError(msg)

        try:
            return cls(unit=match.group('unit'))
        except (KeyError, TypeError, ValueError) as err:
            raise TypeError(msg) from err

    # Required for all ExtensionDtype subclasses
    @property
    def type(self):
        """
        The scalar type for the array (e.g., int).
        """
        return np.generic

    # Required for all ExtensionDtype subclasses
    @property
    def name(self) -> str:
        """
        A string representation of the dtype.
        """
        return str(self)

    @property
    def unit(self) -> str:
        """
        The angle unit.
        """
        return self._unit

2.
AngleArray

对于每一个

ExtensionArray
,必须具体实现 11 个方法:

  • _from_sequence
  • _from_factorized
  • __getitem__
  • __len__
  • __eq__
  • dtype
  • nbytes
  • isna
  • take
  • copy
  • _concat_same_type

对于测试套件:

  • 需要许多更具体的方法
  • 每当测试提示我添加新方法时,我都会用注释对其进行标记(尽管这不是一个全面的映射,因为大多数方法都需要多个测试)
class AngleArray(pd.api.extensions.ExtensionArray):
    """
    An ExtensionArray for unit-aware angular data.

    The values live in a NumPy array (``_data``) next to the unit they are
    expressed in (``_unit``, either ``'rad'`` or ``'deg'``).
    """
    # Units supported by this array (the same set AngleDtype accepts)
    _VALID_UNITS = ('rad', 'deg')

    # Include `copy` param for TestInterfaceTests
    def __init__(self, data, unit='rad', copy: bool=False):
        if unit is None:
            unit = 'rad'

        # Fail early instead of on the first `.dtype` access
        if unit not in self._VALID_UNITS:
            msg = f"'{type(self).__name__}' only supports 'rad' and 'deg' units"
            raise ValueError(msg)

        # `np.array(data, copy=False)` raises on NumPy >= 2 whenever a copy
        # cannot be avoided (e.g. for a list), so only copy when asked to.
        self._data = np.array(data, copy=True) if copy else np.asarray(data)
        self._unit = unit

    # Required for all ExtensionArray subclasses
    def __getitem__(self, index: int) -> "AngleArray | Any":
        """
        Select a subset of self.

        An integer (Python or NumPy) returns the scalar at that position;
        anything else returns a new AngleArray with the same unit.
        """
        if pd.api.types.is_integer(index):
            return self._data[index]

        # Check index for TestGetitemTests
        index = pd.api.indexers.check_array_indexer(self, index)
        return type(self)(self._data[index], unit=self._unit)

    # TestSetitemTests
    def __setitem__(self, index: int, value: np.generic) -> None:
        """
        Set one or more values in-place.

        Raises
        ------
        ValueError
            If a single position is assigned a sequence.
        """
        # Check index for TestSetitemTests
        index = pd.api.indexers.check_array_indexer(self, index)

        if isinstance(value, type(self)):
            # Store the raw values, expressed in this array's unit
            value = value.asunit(self._unit)._data
        elif pd.api.types.is_scalar(value) and pd.isna(value):
            # None / pd.NA / NaN are all stored as NaN
            value = np.nan

        # A single slot cannot hold a sequence (TestSetitemTests)
        if pd.api.types.is_scalar(index) and not pd.api.types.is_scalar(value):
            raise ValueError('setting an array element with a sequence.')

        # Upcast (e.g. int -> float when assigning NaN) for TestMethodsTests
        target = np.result_type(self._data.dtype, np.asarray(value).dtype)
        if target != self._data.dtype:
            self._data = self._data.astype(target)

        self._data[index] = value

    # Required for all ExtensionArray subclasses
    def __len__(self) -> int:
        """
        Length of this array.
        """
        return len(self._data)

    # TestUnaryOpsTests
    def __invert__(self) -> "AngleArray":
        """
        Element-wise inverse of this array.
        """
        return type(self)(~self._data, unit=self._unit)

    def _ensure_same_units(self, other) -> "AngleArray":
        """
        Helper method to ensure `self` and `other` have the same units.

        An AngleArray in a different unit is converted to this array's unit;
        every other operand is returned unchanged.
        """
        if isinstance(other, type(self)) and other.unit != self._unit:
            return other.asunit(self._unit)
        return other

    def _apply_operator(self, op, other, recast=False) -> "np.ndarray | AngleArray":
        """
        Helper method to apply an operator `op` between `self` and `other`.

        `op` is a dunder name such as ``'__add__'``.

        Some ops require the result to be recast into AngleArray:
        * Comparison ops: recast=False
        * Arithmetic ops: recast=True
        """
        other = self._ensure_same_units(other)
        rhs = other._data if isinstance(other, type(self)) else np.asarray(other)
        result = getattr(self._data, op)(rhs)

        # Let Python try the reflected operation rather than wrapping the
        # NotImplemented sentinel in an array
        if result is NotImplemented:
            return NotImplemented

        return type(self)(result, unit=self._unit) if recast else result

    def _apply_operator_if_not_series(self, op, other, recast=False) -> "np.ndarray | AngleArray":
        """
        Wraps _apply_operator only if `other` is not Series/DataFrame.

        Some ops should return NotImplemented if `other` is a Series/DataFrame:
        https://github.com/pandas-dev/pandas/blob/e7e7b40722e421ef7e519c645d851452c70a7b7c/pandas/tests/extension/base/ops.py#L115
        """
        if isinstance(other, (pd.Series, pd.DataFrame)):
            return NotImplemented
        return self._apply_operator(op, other, recast=recast)

    # Required for all ExtensionArray subclasses
    @pd.core.ops.unpack_zerodim_and_defer('__eq__')
    def __eq__(self, other):
        return self._apply_operator('__eq__', other, recast=False)

    # TestComparisonOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__ne__')
    def __ne__(self, other):
        return self._apply_operator('__ne__', other, recast=False)

    # TestComparisonOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__lt__')
    def __lt__(self, other):
        return self._apply_operator('__lt__', other, recast=False)

    # TestComparisonOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__gt__')
    def __gt__(self, other):
        return self._apply_operator('__gt__', other, recast=False)

    # TestComparisonOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__le__')
    def __le__(self, other):
        return self._apply_operator('__le__', other, recast=False)

    # TestComparisonOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__ge__')
    def __ge__(self, other):
        return self._apply_operator('__ge__', other, recast=False)

    # TestArithmeticOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__add__')
    def __add__(self, other) -> "AngleArray":
        return self._apply_operator_if_not_series('__add__', other, recast=True)

    # TestArithmeticOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__sub__')
    def __sub__(self, other) -> "AngleArray":
        return self._apply_operator_if_not_series('__sub__', other, recast=True)

    # TestArithmeticOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__mul__')
    def __mul__(self, other) -> "AngleArray":
        return self._apply_operator_if_not_series('__mul__', other, recast=True)

    # TestArithmeticOpsTests
    @pd.core.ops.unpack_zerodim_and_defer('__truediv__')
    def __truediv__(self, other) -> "AngleArray":
        return self._apply_operator_if_not_series('__truediv__', other, recast=True)

    # Required for all ExtensionArray subclasses
    @classmethod
    def _from_sequence(cls, data, dtype=None, copy: bool=False):
        """
        Construct a new AngleArray from a sequence of scalars.

        When `data` is already an AngleArray it is converted to the requested
        unit (or keeps its own unit if `dtype` is None) rather than being
        relabelled.

        Raises
        ------
        ValueError
            If `dtype` is given and is not an AngleDtype.
        """
        if dtype is None:
            dtype = data.dtype if isinstance(data, cls) else AngleDtype()

        if not isinstance(dtype, AngleDtype):
            msg = f"'{cls.__name__}' only supports 'AngleDtype' dtype"
            raise ValueError(msg)

        if isinstance(data, cls):
            data = data.asunit(dtype.unit)._data

        return cls(data, unit=dtype.unit, copy=copy)

    # TestParsingTests
    @classmethod
    def _from_sequence_of_strings(cls, strings, *, dtype=None, copy: bool=False) -> "AngleArray":
        """
        Construct a new AngleArray from a sequence of strings.
        """
        scalars = pd.to_numeric(strings, errors='raise')
        return cls._from_sequence(scalars, dtype=dtype, copy=copy)

    # Required for all ExtensionArray subclasses
    @classmethod
    def _from_factorized(cls, uniques: np.ndarray, original: "AngleArray"):
        """
        Reconstruct an AngleArray after factorization.
        """
        return cls(uniques, unit=original.unit)

    # Required for all ExtensionArray subclasses
    @classmethod
    def _concat_same_type(cls, to_concat: "Sequence[AngleArray]") -> "AngleArray":
        """
        Concatenate multiple AngleArrays.

        The result uses the most common unit among the inputs (ties go to the
        unit seen first); arrays in another unit are converted first.
        """
        units = [array.unit for array in to_concat]
        unit = max(dict.fromkeys(units), key=units.count)

        # `asunit` returns the array itself when the unit already matches
        chunks = [array.asunit(unit)._data for array in to_concat]
        return cls(np.concatenate(chunks), unit=unit)

    # Required for all ExtensionArray subclasses
    @property
    def dtype(self):
        """
        An instance of AngleDtype.
        """
        return AngleDtype(self._unit)

    # Required for all ExtensionArray subclasses
    @property
    def nbytes(self) -> int:
        """
        The number of bytes needed to store this object in memory.
        """
        return self._data.nbytes

    @property
    def unit(self):
        """
        The unit ('rad' or 'deg') the stored values are expressed in.
        """
        return self._unit

    # Test*ReduceTests
    def all(self) -> bool:
        return all(self)

    def any(self) -> bool:  # Test*ReduceTests
        return any(self)

    def sum(self) -> np.generic:  # Test*ReduceTests
        return self._data.sum()

    def mean(self) -> np.generic:  # Test*ReduceTests
        return self._data.mean()

    def max(self) -> np.generic:  # Test*ReduceTests
        return self._data.max()

    def min(self) -> np.generic:  # Test*ReduceTests
        return self._data.min()

    def prod(self) -> np.generic:  # Test*ReduceTests
        return self._data.prod()

    def std(self) -> np.generic:  # Test*ReduceTests
        return pd.Series(self._data).std()

    def var(self) -> np.generic:  # Test*ReduceTests
        return pd.Series(self._data).var()

    def median(self) -> np.generic:  # Test*ReduceTests
        return np.median(self._data)

    def skew(self) -> np.generic:  # Test*ReduceTests
        return pd.Series(self._data).skew()

    def kurt(self) -> np.generic:  # Test*ReduceTests
        return pd.Series(self._data).kurt()

    # Test*ReduceTests
    def _reduce(self, name: str, *, skipna: bool=True, **kwargs):
        """
        Return a scalar result of performing the reduction operation.

        `name` is looked up as a method on this array (e.g. 'sum', 'max').
        NOTE: `skipna` and any extra keyword arguments are currently ignored;
        each reduction applies its own NaN handling.
        """
        return getattr(self, name)()

    # Required for all ExtensionArray subclasses
    def isna(self):
        """
        A 1-D array indicating if each value is missing.
        """
        return pd.isna(self._data)

    # Required for all ExtensionArray subclasses
    def copy(self):
        """
        Return a copy of the array.
        """
        return type(self)(self._data.copy(), unit=self._unit)

    # Required for all ExtensionArray subclasses
    def take(self, indices, allow_fill=False, fill_value=None):
        """
        Take elements from an array.

        The result keeps this array's unit. With `allow_fill`, positions
        given as -1 are filled with `fill_value` (the dtype's NA value when
        `fill_value` is None).
        """
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = pd.api.extensions.take(self._data, indices, allow_fill=allow_fill,
                                        fill_value=fill_value)
        return type(self)(result, unit=self._unit)

    # TestMethodsTests
    def value_counts(self, dropna: bool=True):
        """
        Return a Series containing descending counts of unique values (excludes NA values by default).
        """
        return pd.Series(self._data).value_counts(dropna=dropna)

    def asunit(self, unit: str) -> "AngleArray":
        """
        Cast to an AngleDtype unit.

        Returns `self` (not a copy) when the array is already in `unit`.

        Raises
        ------
        ValueError
            If `unit` is neither 'rad' nor 'deg'.
        """
        if unit not in self._VALID_UNITS:
            msg = f"'{type(self.dtype).__name__}' only supports 'rad' and 'deg' units"
            raise ValueError(msg)

        if unit == self._unit:
            return self

        # Only two units exist, so the target unit fixes the direction
        convert = np.rad2deg if unit == 'deg' else np.deg2rad
        return type(self)(convert(self._data), unit)

3.
pytest

$ pytest tests.py
...
2 failed, 398 passed, 1 skipped, 1 xfailed in 3.95s

还有两个测试失败:

  1. TestMethodsTests.test_combine_le

    目前这会返回一个 dtype 为

    AngleDtype
    、内容为布尔值的 Series,但 pandas 期望该 Series 本身的 dtype 就是布尔型(不知道如何在不破坏其他测试的情况下解决这个问题):

    pd.Series(a).combine(pd.Series(a), lambda x1, x2: x1 <= x2)
    
  2. TestSetitemTests.test_setitem_scalar_key_sequence_raise

    目前这会将

    a[[0, 1]]
    放入索引 0,但 pandas 期望这里抛出错误:

    a[0] = a[[0, 1]]
    

    一些 pandas 扩展数组使用复杂的验证方法来捕获这些边缘情况,例如:

import operator

import numpy as np
from pandas import Series
import pytest

from pandas.tests.extension.base.casting import BaseCastingTests  # noqa
from pandas.tests.extension.base.constructors import BaseConstructorsTests  # noqa
from pandas.tests.extension.base.dtype import BaseDtypeTests  # noqa
from pandas.tests.extension.base.getitem import BaseGetitemTests  # noqa
from pandas.tests.extension.base.groupby import BaseGroupbyTests  # noqa
from pandas.tests.extension.base.interface import BaseInterfaceTests  # noqa
from pandas.tests.extension.base.io import BaseParsingTests  # noqa
from pandas.tests.extension.base.methods import BaseMethodsTests  # noqa
from pandas.tests.extension.base.missing import BaseMissingTests  # noqa
from pandas.tests.extension.base.ops import (  # noqa
    BaseArithmeticOpsTests,
    BaseComparisonOpsTests,
    BaseOpsUtil,
    BaseUnaryOpsTests,
)
from pandas.tests.extension.base.printing import BasePrintingTests  # noqa
from pandas.tests.extension.base.reduce import (  # noqa
    BaseBooleanReduceTests,
    BaseNoReduceTests,
    BaseNumericReduceTests,
)
from pandas.tests.extension.base.reshaping import BaseReshapingTests  # noqa
from pandas.tests.extension.base.setitem import BaseSetitemTests  # noqa

from extension import AngleDtype, AngleArray


@pytest.fixture
def dtype():
    """
    The ExtensionDtype under test: a default-constructed AngleDtype.
    """
    angle_dtype = AngleDtype()
    return angle_dtype


@pytest.fixture
def data():
    """
    Provide the main length-100 array used throughout the suite.

    The values are 0..99, so the first two entries are both present
    (non-missing) and differ from each other.
    """
    values = np.arange(100)
    return AngleArray(values)


@pytest.fixture
def data_for_twos():
    """
    Array of one hundred elements, every one of them equal to two.
    """
    twos = np.array([2] * 100)
    return AngleArray(twos)


@pytest.fixture
def data_missing():
    """
    Two-element array laid out as [NA, valid].
    """
    values = np.array([np.nan, 2])
    return AngleArray(values)


@pytest.fixture(params=['data', 'data_missing'])
def all_data(request, data, data_missing):
    """
    Run the requesting test once with `data` and once with `data_missing`.
    """
    choices = {'data': data, 'data_missing': data_missing}
    return choices.get(request.param)


@pytest.fixture
def data_repeated(data):
    """
    Build a factory that hands out the `data` fixture several times.

    Parameters
    ----------
    data : fixture implementing `data`

    Returns
    -------
    Callable[[int], Generator]:
        Called with `count`, it returns a generator that yields the same
        `data` object `count` times.
    """
    def gen(count):
        yield from (data for _ in range(count))

    return gen


@pytest.fixture
def data_for_sorting():
    """
    Three-element array with a known sort order.
    Laid out as [B, C, A] where A < B < C.
    """
    ordered = [2, 3, 1]  # B, C, A
    return AngleArray(np.array(ordered))


@pytest.fixture
def data_missing_for_sorting():
    """
    Three-element array with a known sort order and one missing value.
    Laid out as [B, NA, A] where A < B.
    """
    ordered = [2, np.nan, 1]  # B, NA, A
    return AngleArray(np.array(ordered))


@pytest.fixture
def na_cmp():
    """
    Provide the binary predicate used to compare two scalar NA values.

    The returned callable takes two arguments and is true when they are
    equal under ``np.array_equal`` with NaNs treated as equal to each other.
    """
    def both_na(a, b):
        return np.array_equal(a, b, equal_nan=True)

    return both_na


@pytest.fixture
def na_value():
    """
    The scalar missing value for this type (``np.nan`` here).
    """
    missing = np.nan
    return missing


@pytest.fixture
def data_for_grouping():
    """
    Array used by the factorization, grouping and unique tests.
    Laid out as [B, B, NA, NA, A, A, B, C] where A < B < C.
    """
    layout = [2, 2, np.nan, np.nan, 1, 1, 2, 3]
    return AngleArray(np.array(layout))


@pytest.fixture(params=[True, False])
def box_in_series(request):
    """
    Flag: should the data be wrapped in a Series?
    """
    boxed = request.param
    return boxed


@pytest.fixture(
    params=[
        lambda x: 1,
        lambda x: [1] * len(x),
        lambda x: Series([1] * len(x)),
        lambda x: x,
    ],
    ids=['scalar', 'list', 'series', 'object'],
    # The pandas base groupby tests request this fixture as
    # `groupby_apply_op`; register it under that name so they can find it.
    name='groupby_apply_op',
)
def groupby_apply_operator(request):
    """
    Functions to test groupby.apply().

    One callable per parametrization, returning respectively a scalar, a
    list, a Series, and the group itself.
    """
    return request.param


@pytest.fixture(params=[True, False])
def as_frame(request):
    """
    Flag used to compare a Series against its Series.to_frame() counterpart.
    """
    flag = request.param
    return flag


@pytest.fixture(params=[True, False])
def as_series(request):
    """
    Flag used to compare a raw array against Series(array).
    """
    flag = request.param
    return flag


@pytest.fixture(params=[True, False])
def use_numpy(request):
    """
    Flag used to compare an ExtensionDtype array against a numpy array.
    """
    flag = request.param
    return flag


@pytest.fixture(params=['ffill', 'bfill'])
def fillna_method(request):
    """
    Supply 'ffill' and then 'bfill' as the method for
    Series.fillna(method=<method>) tests.
    """
    method = request.param
    return method


@pytest.fixture(params=[True, False])
def as_array(request):
    """
    Flag used by the ExtensionDtype _from_sequence tests.
    """
    flag = request.param
    return flag


@pytest.fixture(params=[None, lambda item: item])
def sort_by_key(request):
    """
    Key functions for the sorting tests: no key at all (None) and the
    identity function.
    """
    key = request.param
    return key


# TODO: Finish implementing all operators
# Only the uncommented dunder names are parametrized; the commented-out
# entries are the operators that are not enabled yet.
_all_arithmetic_operators = [
    '__add__',
    #  '__radd__',
    '__sub__',
    #  '__rsub__',
    '__mul__',
    #  '__rmul__',
    #  '__floordiv__',
    #  '__rfloordiv__',
    '__truediv__',
    #  '__rtruediv__',
    #  '__pow__',
    #  '__rpow__',
    #  '__mod__',
    #  '__rmod__',
]
@pytest.fixture(params=_all_arithmetic_operators)
def all_arithmetic_operators(request):
    """
    Fixture for dunder names for common arithmetic operations.

    Yields one name from `_all_arithmetic_operators` per parametrization.
    """
    return request.param


# Names of the numeric reductions exercised by the reduce tests
_all_numeric_reductions = [
    'sum', 'max', 'min', 'mean', 'prod',
    'std', 'var', 'median', 'kurt', 'skew',
]
@pytest.fixture(params=_all_numeric_reductions)
def all_numeric_reductions(request):
    """
    Supply each numeric reduction name in turn.
    """
    reduction = request.param
    return reduction


# Names of the boolean reductions exercised by the reduce tests
_all_boolean_reductions = [
    'all',
    'any',
]
@pytest.fixture(params=_all_boolean_reductions)
def all_boolean_reductions(request):
    """
    Supply each boolean reduction name in turn.
    """
    reduction = request.param
    return reduction


# Numeric reductions first, then the boolean ones
_all_reductions = [*_all_numeric_reductions, *_all_boolean_reductions]
@pytest.fixture(params=_all_reductions)
def all_reductions(request):
    """
    Supply every reduction name (numeric followed by boolean) in turn.
    """
    reduction = request.param
    return reduction


# Dunder names of the six comparison operators
_all_compare_operators = [
    '__eq__', '__ne__',
    '__le__', '__lt__',
    '__ge__', '__gt__',
]
@pytest.fixture(params=_all_compare_operators)
def all_compare_operators(request):
    """
    Supply the dunder name of each comparison operator in turn:

    * ==
    * !=
    * <=
    * <
    * >=
    * >
    """
    op_name = request.param
    return op_name


# Each class below subclasses one of the pandas base extension test classes
# imported above without overriding anything, so the inherited tests run
# against the fixtures defined in this module.
class TestCastingTests(BaseCastingTests):
    pass


class TestConstructorsTests(BaseConstructorsTests):
    pass


class TestDtypeTests(BaseDtypeTests):
    pass


class TestGetitemTests(BaseGetitemTests):
    pass


class TestGroupbyTests(BaseGroupbyTests):
    pass


class TestInterfaceTests(BaseInterfaceTests):
    pass


class TestParsingTests(BaseParsingTests):
    pass


class TestMethodsTests(BaseMethodsTests):
    pass


class TestMissingTests(BaseMissingTests):
    pass


class TestArithmeticOpsTests(BaseArithmeticOpsTests):
    # NOTE(review): these attributes presumably tell the base class which
    # exception each kind of operation is expected to raise, with None
    # meaning "expected to succeed" — confirm against the pandas
    # BaseArithmeticOpsTests for the installed version.
    series_scalar_exc = None
    frame_scalar_exc = None
    series_array_exc = None
    divmod_exc = TypeError  # TODO: Implement divmod


class TestComparisonOpsTests(BaseComparisonOpsTests):
    # Approach borrowed from the pint-pandas test suite
    def _compare_other(self, s, data, op_name, other):
        """
        Check that comparing `s` with `other` gives the same result as
        comparing the NumPy values of `s` with `other`.
        """
        compare = self.get_op_from_name(op_name)
        actual = compare(s, other)
        reference = compare(s.to_numpy(), other)
        assert (actual == reference).all()


# The remaining classes subclass the pandas base extension test classes
# without overriding anything.
class TestOpsUtil(BaseOpsUtil):
    pass


class TestUnaryOpsTests(BaseUnaryOpsTests):
    pass


class TestPrintingTests(BasePrintingTests):
    pass


class TestBooleanReduceTests(BaseBooleanReduceTests):
    pass


class TestNumericReduceTests(BaseNumericReduceTests):
    pass


# AFAICT NoReduce and Boolean+NumericReduce are mutually exclusive, so the
# no-reduce tests are left disabled.
# class TestNoReduceTests(BaseNoReduceTests):
    # pass


class TestReshapingTests(BaseReshapingTests):
    pass


class TestSetitemTests(BaseSetitemTests):
    pass

0
投票

由于声望(reputation)不足,我目前无法发表评论。作为对 @tdy 的优秀答案的补充:您可以采用与 StringArray 的实现相同的方法,抛出

test_setitem_scalar_key_sequence_raise
所需的特定错误。

def __setitem__(self, key, value):
    # Snippet only: it shows the validation steps and stops before the
    # actual assignment. `extract_array`, `check_array_indexer` and `lib`
    # are pandas internals that must be imported by the surrounding module.

    # (1) Do whatever you must to get the right value to assign
    value = extract_array(value, extract_numpy=True)
    if isinstance(value, type(self)):
        # extract_array doesn't extract PandasArray subclasses
        value = value._ndarray

    # (2) Get the indexer
    key = check_array_indexer(self, key)

    # (3) Raise the required error if they are not both scalar.
    scalar_key = lib.is_scalar(key)
    scalar_value = lib.is_scalar(value)
    if scalar_key and not scalar_value:
        raise ValueError("setting an array element with a sequence.")
    
© www.soinside.com 2019 - 2024. All rights reserved.