Python 中 2D 转置函数的正确类型注释

问题描述 投票:0回答:1

在我们的代码中,我们做了很多 2D 转置,例如

zip(*something)
,其中
something
是元组列表。

例如:

>>> a = [('a', 1), ('b', 2), ('c', 3)]
>>> result = tuple(*zip(*a))
>>> result # type checkers should reveal tuple[Iterable[str], Iterable[int]]
[('a', 'b', 'c'), (1, 2, 3)]

在 MyPy(PlayIssue)和 Pyright(Play)中,

tuple(zip(*a))
的类型显示类似
tuple[Any]
的内容。这会导致代码丢失所有类型信息,这是我不想要的。

因此,我想创建一个特殊的函数

def transpose(iterable)
函数来确保类型检查器了解结果类型。 一种不丢失类型信息的解决方法。

如何做到这一点?

python python-typing
1个回答
0
投票

一个不太优雅但有效的解决方案是依赖于重载

只要元组中的元素数量很少(在我的示例中为 5),这就会起作用。

我尝试使用

TypeVarTuple
而不是 overloads 但无法让它工作。

from typing import TypeVar, Iterable

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")


@overload
def transpose(
    iterable: Iterable[tuple[T1, T2, T3, T4, T5]], strict: bool = False
) -> tuple[Iterable[T1], Iterable[T2], Iterable[T3], Iterable[T4], Iterable[T5]]:
    ...


@overload
def transpose(
    iterable: Iterable[tuple[T1, T2, T3, T4]], strict: bool = False
) -> tuple[Iterable[T1], Iterable[T2], Iterable[T3], Iterable[T4]]:
    ...


@overload
def transpose(
    iterable: Iterable[tuple[T1, T2, T3]], strict: bool = False
) -> tuple[Iterable[T1], Iterable[T2], Iterable[T3]]:
    ...


@overload
def transpose(
    iterable: Iterable[tuple[T1, T2]], strict: bool = False
) -> tuple[Iterable[T1], Iterable[T2]]:
    ...


def transpose(
    iterable: (
        Iterable[tuple[T1, T2]]
        | Iterable[tuple[T1, T2, T3]]
        | Iterable[tuple[T1, T2, T3, T4]]
        | Iterable[tuple[T1, T2, T3, T4, T5]]
    ),
    strict: bool = False,
) -> (
    tuple[Iterable[T1], Iterable[T2]]
    | tuple[Iterable[T1], Iterable[T2], Iterable[T3]]
    | tuple[Iterable[T1], Iterable[T2], Iterable[T3], Iterable[T4]]
    | tuple[Iterable[T1], Iterable[T2], Iterable[T3], Iterable[T4], Iterable[T5]]
):
    """
    Transpose the elements of given iterable, type safe

    Only a typed shortcut for zip(*iterable)
    See https://github.com/python/mypy/issues/5247 for background
    """
    return zip(*iterable, strict=strict)  # type: ignore

© www.soinside.com 2019 - 2024. All rights reserved.