有人可以告诉我如何在 FastAPI 路由器方法之外为我的
get_db()
使用依赖注入吗?显然,Depends()
仅涵盖请求函数中的 DI。
这是
get_db()
异步生成器:
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
yield session
在 FastAPI 路由器中,我可以简单地使用
Depends()
,如下所示:
@router.get("/interactions", response_model=List[schemas.Interaction])
async def get_all_interactions(db: Annotated[AsyncSession, Depends(get_db)]) -> List[schemas.Interaction]:
interactions = await crud.get_interactions(db=db)
return [
schemas.Interaction.model_validate(interaction) for interaction in interactions
]
现在,在请求之外,如何在新方法中注入
get_db
并摆脱方法内部的 async for
?
@cli.command(name="create_superuser")
async def create_superuser(): # Note: how to pass db session here as param?
username = click.prompt("Username", type=str)
email = click.prompt("Email (optional)", type=str, default="")
password = getpass("Password: ")
confirm_password = getpass("Confirm Password: ")
if password != confirm_password:
click.echo("Passwords do not match")
return
async for db in database.get_db(): # Note: remove it from here
user = schemas.UserAdminCreate(
username=username,
email=None if not email else email,
password=password,
role="admin",
)
await crud.create_user(db=db, user=user)
PS:这个要求的原因是,我将为
create_superuser()
函数编写一个测试用例,它有自己的数据库和各自的会话,因此将会话数据库注入到任何方法。
最终,我可以实现一个简单的依赖注入器来解析异步生成器并随后到达数据库异步会话,几乎与 FastAPI 中的
Depends()
类似:
import asyncclick as click
import inspect
from getpass import getpass
from . import crud, schemas, database
class Provide:
def __init__(self, value):
self.value = value
def inject_db(f):
sig = inspect.signature(f)
async def wrapper(*args, **kwargs):
for param in sig.parameters.values():
if isinstance(param.default, Provide):
async for db in param.default.value():
kwargs[param.name] = db
await f(*args, **kwargs)
return wrapper
@cli.command(name="create_superuser")
@inject_db
async def create_superuser(db: database.AsyncSession = Provide(database.get_db)):
username = click.prompt("Username", type=str)
email = click.prompt("Email (optional)", type=str, default="")
password = getpass("Password: ")
confirm_password = getpass("Confirm Password: ")
user = schemas.UserAdminCreate(
username=username,
email=None if not email else email,
password=password,
role="admin",
)
await crud.create_user(db=db, user=user)