在 Python 类型中,有没有一种方法可以指定允许的依赖泛型类型的组合?

问题描述 投票:0回答:2

我在 Python 中有一个非常小但通用(可能太多)的类,它实际上有两个约束在一起的泛型类型。我的代码(但打字很糟糕)代码应该显示我的意图:

class ApplyTo:
    def __init__(
        self,
        *transforms: Callable[..., Any],
        to: Any | Sequence[Any],
        dispatch: Literal['separate', 'joint'] = 'separate',
    ):
        self._transform = TransformsPipeline(*transforms)
        self._to = to if isinstance(to, Sequence) else [to]
        self._dispatch = dispatch

    def __call__(self, data: MutableSequence | MutableMapping):
        if self._dispatch == 'separate':
            for key in self._to:
                data[key] = self._transform(data[key])
            return data

        if self._dispatch == 'joint':
            args = [data[key] for key in self._to]
            return self._transform(*args)

        assert False

我已经仔细检查过它在运行时是否有效并且非常简单,但是打字真的很可怕。

所以想法是,当我们将

to
设置为
int
时,那么
data
应该是
MutableSequence | MutableMapping[int, Any]
;当
to
Hashable
时,那么
data
应该是
MutableMapping[Hashable or whatever type of to is, Any]
。我知道
int
Hashable
这并不会让这件事变得更容易。

我非常糟糕的打字尝试看起来像这样:

T = TypeVar('T', bound=Hashable | int)
C = TypeVar('C', bound=MutableMapping[T, Any] | MutableSequence)


class ApplyTo(Generic[C, T]):
    def __init__(
        self,
        *transforms: Callable[..., Any],
        to: T | Sequence[T],
        dispatch: Literal['separate', 'joint'] = 'separate',
    ):
        self._transform = TransformsPipeline(*transforms)
        self._to = to if isinstance(to, Sequence) else [to]
        self._dispatch = dispatch

    def __call__(self, data: C):
        if self._dispatch == 'separate':
            for key in self._to:
                data[key] = self._transform(input[key])
            return input

        if self._dispatch == 'joint':
            args = [data[key] for key in self._to]
            return self._transform(*args)

        assert False

这让 mypy 抱怨(毫不奇怪):

error: Type variable "task_driven_sr.transforms.generic.T" is unbound  [valid-type]
note: (Hint: Use "Generic[T]" or "Protocol[T]" base class to bind "T" inside a class)
note: (Hint: Use "T" in function signature to bind "T" inside a function)

有没有办法正确输入提示并以某种方式将

to
data
的类型绑定在一起?也许我的方法有缺陷,我已经走进了死胡同。

python mypy python-typing
2个回答
0
投票

如果我没理解错的话,让我哭的那句台词是:

C = TypeVar('C', bound=MutableMapping[T, Any] | MutableSequence)

还有一些其他主题处理类似的问题: 类型变量...未绑定

您可以在此处找到有关 mypy 文档的更多详细信息: https://mypy.readthedocs.io/en/latest/generics.html#generic-type-aliases

基本上,mypy 告诉您它不能将 TypeVar 与另一个 TypeVar 绑定(请记住 TypeVar 应该与 Generic 一起使用)。

我认为您的问题的解决方案可能是使用通用别名:

from typing import Any, Callable, Literal, Generic, Hashable, MutableMapping, MutableSequence, Sequence, TypeVar

T = TypeVar('T', bound=Hashable | int)
C = MutableMapping[T, Any] | MutableSequence  # GenericAlias !


class ApplyTo(Generic[T]):
    def __init__(to: T | Sequence[T]):
        self._to = to if isinstance(to, Sequence) else [to]

    def __call__(self, data: C):
        ...

然后,您可以像这样使用ApplyTo:

ApplyTo[int](to=5)
ApplyTo[int](to=[5, 6, 7])

如果我没有正确理解您的问题或不清楚,请告诉我!


0
投票

您可以将

MutableMapping
MutableSequence
的并集替换为
Protocol

import typing

DispatchType = typing.Literal['separate', 'joint']

# `P` must be declared with `contravariant=True`, otherwise it errors with
# 'Invariant type variable "P" used in protocol where contravariant one is expected'
K = typing.TypeVar('K', contravariant=True)
class Indexable(typing.Protocol[K]):
    def __getitem__(self, key: K):
        pass
    
    def __setitem__(self, key: K, value: typing.Any):
        pass

# Accepts only hashable types (including `int`s)
H = typing.TypeVar('H', bound=typing.Hashable)
class ApplyTo(typing.Generic[H]):
    _to: typing.Sequence[H]
    _dispatch: DispatchType
    _transform: typing.Callable[..., typing.Any]  # TODO Initialize `_transform`

    def __init__(self, to: typing.Sequence[H] | H, dispatch: DispatchType = 'separate') -> None:
        self._dispatch = dispatch
        self._to = to if isinstance(to, typing.Sequence) else [to]

    def __call__(self, data: Indexable[H]) -> typing.Any:
        if self._dispatch == 'separate':
            for key in self._to:
                data[key] = self._transform(data[key])
            return data

        if self._dispatch == 'joint':
            args = [data[key] for key in self._to]
            return self._transform(*args)

        assert False

用途:

def main() -> None:
    r0 = ApplyTo(to=0)([1, 2, 3])
    # typechecks
    r0 = ApplyTo(to=0)({1: 'a', 2: 'b', 3: 'c'})
    # typechecks

    r1 = ApplyTo(to='a')(['b', 'c', 'd'])
    # does not typecheck: Argument 1 to "__call__" of "Applier" has incompatible type "list[str]"; expected "Indexable[str]"
    r1 = ApplyTo(to='a')({'b': 1, 'c': 2, 'd': 3}) 
    # typechecks
© www.soinside.com 2019 - 2024. All rights reserved.