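I'm trying to implement a decorator that injects a `DBConnection`. The problem is that I want to support both cases: passing the argument explicitly, and letting the decorator inject it. I tried to do this with `@overload`, but it didn't work. Here is reproducible code: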
```python
import inspect
from functools import wraps
from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from typing_extensions import reveal_type


class DBConnection:
    ...


T = TypeVar("T")
P = ParamSpec("P")


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], Awaitable[T]]
) -> Callable[P, Awaitable[T]]:
    @wraps(f)
    async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        signature = inspect.signature(f).parameters
        passed_args = dict(zip(signature, args))
        if "db_connection" in kwargs or "db_connection" in passed_args:
            return await f(*args, **kwargs)
        return await f(DBConnection(), *args, **kwargs)

    return inner


@inject_db_connection
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    assert db_connection
    return {"user_id": user_id}


async def main() -> None:
    # ↓ No issue, great!
    user1 = await get_user(user_id=1)
    # ↓ Understandably fails with:
    # `Unexpected keyword argument "db_connection" for "get_user" [call-arg]`
    # but I would like to support passing `db_connection` explicitly as well.
    db_connection = DBConnection()
    user2 = await get_user(db_connection=db_connection, user_id=1)
    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user1)
    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user2)
```
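The detection in `inner` relies on iterating `inspect.signature(f).parameters`, which yields parameter names in declaration order, so `dict(zip(signature, args))` pairs each positional argument with the name it fills. A small sketch of just that step (the `positional_names` helper is hypothetical, introduced only for illustration) shows one consequence: a purely positional call such as `get_user(1)` maps `1` onto `db_connection`, so the wrapper would forward the call without injecting a connection.

```python
import inspect


class DBConnection:
    """Stand-in for the real connection class."""


def get_user(db_connection: DBConnection, user_id: int) -> dict:
    return {"user_id": user_id}


def positional_names(f, *args) -> dict:
    """Hypothetical helper mirroring how the question's `inner` builds
    `passed_args`: pair each positional argument with a parameter name."""
    params = inspect.signature(f).parameters  # ordered mapping: name -> Parameter
    return dict(zip(params, args))  # iterating the mapping yields the names


conn = DBConnection()
print(list(positional_names(get_user, conn, 1)))  # ['db_connection', 'user_id']
print(positional_names(get_user, 1))              # {'db_connection': 1}
```

Since `get_user(1)` type-checks against `Callable[P, Awaitable[T]]` (with `P` being `(user_id: int)`), that call is accepted statically but raises a `TypeError` about the missing `user_id` argument at runtime.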
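Using the trick described in this question: Type signature for a function with kwargs (typing.Callable), I managed to write a version that passes type checking, although it does require one `cast`, and I don't know how to avoid that here. I also took the liberty of changing the actual functional part inside `inner` to make it more robust.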
```python
import inspect
from functools import wraps
from typing import Callable, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload


class DBConnection:
    ...


T = TypeVar("T", covariant=True)
P = ParamSpec("P")


class CallMaybeDB(Protocol[P, T]):
    @overload
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
    @overload
    def __call__(self, db_connection: DBConnection, *args: P.args, **kwargs: P.kwargs) -> T: ...
    # Unsure if this is required, my IDE complains if it's missing, but mypy doesn't
    def __call__(self, *args, **kwargs) -> T: ...


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], T]
) -> CallMaybeDB[P, T]:
    signature = inspect.signature(f)
    if "db_connection" not in signature.parameters:
        raise TypeError("Function should expect db_connection parameter")

    @wraps(f)
    def inner(*args, **kwargs) -> T:
        # Bind what the caller passed; inject a fresh connection only if absent.
        bound = signature.bind_partial(*args, **kwargs)
        if "db_connection" not in bound.arguments:
            bound.arguments["db_connection"] = DBConnection()
        # bound.args / bound.kwargs are rebuilt in the signature's parameter order.
        return f(*bound.args, **bound.kwargs)

    return cast(CallMaybeDB[P, T], inner)
```
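The `bind_partial` approach above shares the positional pitfall: `signature.bind_partial(1)` binds `1` to `db_connection`, so `get_user(1)`, which the first overload accepts, skips injection and then fails on the missing `user_id`. Below is a minimal runtime-only sketch of one alternative, which swaps `bind_partial` for an `isinstance` check. It assumes `db_connection` is always the first parameter (which `Concatenate[DBConnection, P]` already requires) and that callers only pass real `DBConnection` instances. `inject_db_connection_runtime` is a hypothetical name; in typed code the `CallMaybeDB` return type and the `cast` would stay as above.

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable, TypeVar

T = TypeVar("T")


class DBConnection:
    """Stand-in for the real connection class."""


def inject_db_connection_runtime(f: Callable[..., T]) -> Callable[..., T]:
    # Hypothetical runtime-only variant; assumes db_connection comes first.
    params = list(inspect.signature(f).parameters)
    if not params or params[0] != "db_connection":
        raise TypeError("db_connection must be the first parameter")

    @wraps(f)
    def inner(*args: Any, **kwargs: Any) -> T:
        # The connection counts as supplied only if it was passed by keyword
        # or the first positional argument really is a DBConnection.
        supplied = "db_connection" in kwargs or (
            bool(args) and isinstance(args[0], DBConnection)
        )
        if supplied:
            return f(*args, **kwargs)
        # Prepend positionally so any positional args shift to the next params.
        return f(DBConnection(), *args, **kwargs)

    return inner


@inject_db_connection_runtime
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    assert isinstance(db_connection, DBConnection)
    return {"user_id": user_id}


async def main() -> None:
    conn = DBConnection()
    print(await get_user(user_id=1))                      # {'user_id': 1}
    print(await get_user(1))                              # {'user_id': 1}
    print(await get_user(db_connection=conn, user_id=2))  # {'user_id': 2}
    print(await get_user(conn, 3))                        # {'user_id': 3}


asyncio.run(main())
```

The `isinstance` check is the trade-off here: it misfires if a connection-like object that is not a `DBConnection` is passed positionally, or if the parameter after `db_connection` can itself be a `DBConnection`.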