Type hinting a decorator that injects a value but also supports passing it explicitly


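I'm trying to implement a decorator that injects a `DBConnection`. The problem is that I want to support both: passing the argument explicitly, and having the decorator inject it. I tried to do this with `@overload`, but it didn't work. Here is reproducible code: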

```python
import inspect
from functools import wraps
from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from typing_extensions import reveal_type


class DBConnection:
    ...


T = TypeVar("T")
P = ParamSpec("P")


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], Awaitable[T]]
) -> Callable[P, Awaitable[T]]:
    @wraps(f)
    async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        # Map positional arguments onto parameter names to see whether the
        # caller already supplied db_connection.
        signature = inspect.signature(f).parameters
        passed_args = dict(zip(signature, args))
        if "db_connection" in kwargs or "db_connection" in passed_args:
            return await f(*args, **kwargs)

        return await f(DBConnection(), *args, **kwargs)

    return inner


@inject_db_connection
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    assert db_connection
    return {"user_id": user_id}


async def main() -> None:
    # ↓ No issue, great!
    user1 = await get_user(user_id=1)

    # ↓ Understandably fails with:
    # `Unexpected keyword argument "db_connection" for "get_user"  [call-arg]`
    # but I would like to support passing `db_connection` explicitly as well.
    db_connection = DBConnection()
    user2 = await get_user(db_connection=db_connection, user_id=1)

    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user1)
    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user2)
```

Answer

Using the trick described in this question: typing the signature of a function with kwargs (typing.Callable).
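For context, that trick is the callback-protocol pattern: a `Protocol` whose `__call__` spells out the full signature, which can name keyword parameters (and give them defaults) in a way `Callable[[...], R]` cannot. A minimal sketch of the pattern on its own, assuming hypothetical names `FetchUser` and `fetch_user` that are not part of the answer:

```python
from typing import Optional, Protocol


class DBConnection:
    ...


class FetchUser(Protocol):
    # Callback protocol: __call__ names the parameters, including a
    # keyword-only one with a default, which Callable[[...], R] cannot express.
    def __call__(
        self, user_id: int, *, db_connection: Optional[DBConnection] = None
    ) -> dict: ...


def fetch_user(user_id: int, *, db_connection: Optional[DBConnection] = None) -> dict:
    # Hypothetical body: report whether a connection had to be created here.
    injected = db_connection is None
    conn = DBConnection() if db_connection is None else db_connection
    assert isinstance(conn, DBConnection)
    return {"user_id": user_id, "injected": injected}


# fetch_user matches FetchUser structurally, so a type checker accepts this.
fetch: FetchUser = fetch_user
print(fetch(user_id=1))
print(fetch(2, db_connection=DBConnection()))
```

The answer's `CallMaybeDB` applies the same idea, but with two overloaded `__call__` signatures built from a `ParamSpec`, so it works for any decorated function rather than one fixed signature.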

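I managed to write a version that passes type checking, although it does need one `cast`, and I don't know how to avoid that here. I also took the liberty of changing the functional part inside `inner` to make it more robust.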

```python
import inspect
from functools import wraps
from typing import Callable, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload


class DBConnection:
    ...


T = TypeVar("T", covariant=True)
P = ParamSpec("P")


class CallMaybeDB(Protocol[P, T]):
    # Called without a connection: the decorator injects one.
    @overload
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    # Called with an explicit connection (positionally or by keyword).
    @overload
    def __call__(self, db_connection: DBConnection, *args: P.args, **kwargs: P.kwargs) -> T: ...

    # Unsure if this is required; my IDE complains if it's missing, but mypy doesn't
    def __call__(self, *args, **kwargs) -> T: ...

def inject_db_connection(
        f: Callable[Concatenate[DBConnection, P], T]
) -> CallMaybeDB[P, T]:
    # Fail early, at decoration time, if there is no db_connection parameter.
    signature = inspect.signature(f)
    if "db_connection" not in signature.parameters:
        raise TypeError("Function should expect db_connection parameter")

    @wraps(f)
    def inner(*args, **kwargs) -> T:
        # Bind the call's arguments to parameter names and inject a
        # connection only if the caller did not supply one.
        bound = signature.bind_partial(*args, **kwargs)
        if "db_connection" not in bound.arguments:
            bound.arguments["db_connection"] = DBConnection()

        return f(*bound.args, **bound.kwargs)

    return cast(CallMaybeDB[P, T], inner)
```
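One caveat with the `bind_partial` approach, worth checking before relying on it: `bind_partial` fills parameters left to right, so a purely positional call like `get_user(1)`, which the first overload accepts, binds `1` to `db_connection` and nothing gets injected. The sketch below demonstrates that with plain `inspect`, then shows one possible alternative runtime check. The decorator name `inject_db_connection_runtime` is hypothetical, its static typing is left out (the `CallMaybeDB` return type and `cast` could be applied the same way), and it assumes `db_connection` is the first parameter, which `Concatenate` already requires.

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


class DBConnection:
    ...


async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    # Fails loudly if anything other than a real connection ends up here.
    assert isinstance(db_connection, DBConnection)
    return {"user_id": user_id}


# bind_partial fills parameters left to right, so a lone positional user_id
# is bound to db_connection and a "not in bound.arguments" check skips injection.
bound = inspect.signature(get_user).bind_partial(1)
print(bound.arguments)


def inject_db_connection_runtime(f: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical runtime-only variant; assumes db_connection is the first parameter.
    @wraps(f)
    def inner(*args: Any, **kwargs: Any) -> Any:
        supplied = "db_connection" in kwargs or (
            len(args) > 0 and isinstance(args[0], DBConnection)
        )
        if supplied:
            return f(*args, **kwargs)
        # Otherwise prepend a fresh connection as the first positional argument.
        return f(DBConnection(), *args, **kwargs)

    return inner


wrapped = inject_db_connection_runtime(get_user)


async def demo() -> list:
    conn = DBConnection()
    return [
        await wrapped(1),                              # injected
        await wrapped(user_id=2),                      # injected
        await wrapped(conn, 3),                        # passed positionally
        await wrapped(db_connection=conn, user_id=4),  # passed by keyword
    ]


print(asyncio.run(demo()))
```

The trade-off is that this check relies on `isinstance`, so an object that is not a `DBConnection` (for example a plain mock) passed positionally would be treated as an ordinary argument and a real connection would be injected in front of it.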
