假设我有一个 SQLAlchemy 模型,例如:
from typing import Optional

from sqlalchemy import Boolean
from sqlalchemy.orm import Mapped, mapped_column

from app.data_structures.base import Base


class User(Base):
    """ORM model for rows of the ``users`` table."""

    __tablename__ = "users"

    # Primary key. A primary key can never hold NULL, so the column is left
    # non-nullable (``Mapped[str]`` without ``Optional`` implies NOT NULL).
    user_name: Mapped[str] = mapped_column(primary_key=True)
    # Nullable boolean column, declared in the same annotated style as
    # ``user_name`` instead of a bare ``Column``.
    flag: Mapped[Optional[bool]] = mapped_column(Boolean)

    def __init__(
        self,
        user_name: Optional[str] = None,
        flag: bool = False,
    ) -> None:
        """Create a user.

        :param user_name: primary-key value; it must be set before the row is
            flushed because the column is NOT NULL.
        :param flag: a plain Python ``bool`` (``False`` by default), so
            ``user.flag`` is a real bool even before the row is flushed.
        """
        self.user_name = user_name
        self.flag = flag
其中 app/data_structures/base.py 的内容如下:
from contextlib import contextmanager
from os import environ
from os.path import join, realpath
from sqlalchemy import Column, ForeignKey, Table, create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
# Both variables are required. ``environ[...]`` raises a KeyError that names
# the missing variable, instead of an opaque ``TypeError`` about ``NoneType``
# coming out of ``join`` further down.
db_name = environ["DB_NAME"]
ROOT_DIR = environ["ROOT_DIR"]
# Database file lives at <ROOT_DIR>/data/<DB_NAME>.
db_path = realpath(join(ROOT_DIR, "data", db_name))
# ``connect_args`` is passed to the sqlite3 driver; ``timeout`` is how many
# seconds a connection waits on a locked database before raising.
engine = create_engine(f"sqlite:///{db_path}", connect_args={"timeout": 120})
session_factory = sessionmaker(bind=engine)
# Session registry that ``Session()`` draws its session from.
sql_session = scoped_session(session_factory)
@contextmanager
def Session():
    """Yield the registry's current session inside a commit-or-rollback scope.

    When the ``with`` body finishes normally the session is committed. If the
    body (or the commit itself) raises an ``Exception``, the session is rolled
    back and the exception is re-raised. The session is not closed here.
    """
    active = sql_session()
    try:
        yield active
        active.commit()
    except Exception:
        # Discard the partial transaction, then let the caller see the error.
        active.rollback()
        raise
# Declarative base class shared by the ORM models (e.g. ``User``); its
# ``metadata`` collects their tables, e.g. for ``Base.metadata.create_all``.
Base = declarative_base()
然后我在另一个模块(app.helpers.db_helper.get_users_to_query)中定义了一个函数,执行类似下面的操作:
from app.data_structures.base import Session
def get_usernames_to_query(flag: bool = False) -> List[dict]:
    """Return the names of all users whose ``flag`` column equals *flag*.

    Each entry is a dict of the form ``{"user_name": <name>}``. If any
    ``Exception`` is raised while querying, it is logged (with traceback)
    and an empty list is returned instead.
    """
    logger = getLogger(__name__)
    try:
        with Session() as session:
            # Build plain dicts inside the block, before ``Session()`` commits on exit.
            usernames_to_query = [
                {"user_name": user.user_name}
                for user in session.query(User).filter(User.flag == flag)
            ]
    except Exception as err:
        # Exception args are not guaranteed to be strings; joining them
        # unconverted would raise TypeError inside this handler and defeat the
        # "log and return []" fallback. %-style args defer formatting to logging.
        logger.exception(
            "Exception thrown in get_users_to_query %s",
            " ".join(str(arg) for arg in err.args),
        )
        usernames_to_query = []
    return usernames_to_query
我正在为此应用程序编写单元测试,并尝试模拟 SQLite 数据库和 sqlalchemy,我正在执行以下操作:
from app.helpers.db_helper.get_users_to_query import get_users_to_query
class TestSearchUser(UserHelperTestCase):
    """Tests for ``get_users_to_query`` against a private in-memory SQLite DB."""

    def setUp(self):
        super().setUp()
        # SQLAlchemy's ORM Session class, not the app's ``Session()`` context
        # manager (which takes no arguments and is bound to the real engine).
        from sqlalchemy.orm import Session as OrmSession

        # Fresh in-memory database per test, so production data is never read.
        self.engine = create_engine("sqlite:///:memory:")
        Base.metadata.create_all(self.engine)
        self.session = OrmSession(self.engine)

    def tearDown(self) -> None:
        self.session.close()
        self.engine.dispose()
        return super().tearDown()

    # Patch ``Session`` where it is looked up, not where it is defined: the
    # helper module did ``from app.data_structures.base import Session`` and
    # holds its own reference, so patching ``app.data_structures.base.Session``
    # leaves the helper using the real production session.
    # NOTE(review): path mirrors the import used by this test file; adjust if
    # the helper module lives elsewhere.
    @patch("app.helpers.db_helper.get_users_to_query.Session")
    def test_get_users_to_query(self, mock_session) -> None:
        # The helper does ``with Session() as session``: make that yield our session.
        mock_session.return_value.__enter__.return_value = self.session
        (patcher, environ_dict, environ_mock_get) = self.environ_mock_get_factory()
        with patcher():
            # Seed through the real in-memory session. ``User.__init__`` only
            # accepts ``user_name`` and ``flag``.
            self.session.add_all(
                [
                    User(user_name="jdb1", flag=False),
                    User(user_name="jdb2", flag=True),
                ]
            )
            self.session.commit()
            users_to_query = get_users_to_query()
        # Assumes the default ``flag=False`` shown for the helper: only jdb1 matches.
        self.assertEqual(users_to_query, [{"user_name": "jdb1"}])
但是测试中查询返回的用户始终来自生产数据库,也就是说,我设置的内存数据库和会话并没有被使用。测试这种设置的最佳方法是什么?我也尝试过模拟会话并设置返回值,但没有成功。
您可以模拟(mock)`query()` + `filter()` 调用链。这是一个例子:
pip install sqlalchemy==2.0.20
from unittest import TestCase, mock
from unittest.mock import Mock

from sqlalchemy import String, create_engine
# ``mapped_column`` is public API of ``sqlalchemy.orm``; importing it from the
# private ``sqlalchemy.orm._orm_constructors`` module can break on upgrades.
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

# In-memory SQLite engine used by the whole example; ``echo=True`` logs each SQL statement.
engine = create_engine('sqlite:///:memory:', echo=True)
class Base(DeclarativeBase):
    """Declarative base class for the example's ORM models."""
    pass
class User(Base):
    """Minimal ORM model mapped to the ``user_account`` table."""
    __tablename__ = 'user_account'
    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # User name, a ``String(30)`` column.
    name: Mapped[str] = mapped_column(String(30))
# Create the schema, then insert the single demo row that test_without_mock
# expects to find.
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(User(name='super sonic'))
    session.commit()
class TestExample(TestCase):
    """Contrast a real query with one whose ``query().filter()`` chain is mocked."""

    def test_without_mock(self):
        # The module-level setup inserted exactly one 'super sonic' row.
        with Session(engine) as session:
            users = list(session.query(User).filter(User.name == 'super sonic'))
        self.assertEqual(len(users), 1)

    def test_mock(self):
        # Stub returned by ``session.query(...)``; each ``filter()`` call
        # returns the next list from ``side_effect``.
        mocked_query = Mock(
            filter=Mock(
                side_effect=(
                    [User(name='mock1'), User(name='mock2')],  # result of first filter() call
                    [User(name='mock3'), User(name='mock4')],  # result of second filter() call
                ),
            ),
        )
        # Patching the class attribute affects every Session instance.
        with mock.patch(
            'sqlalchemy.orm.session.Session.query',
            return_value=mocked_query,
        ):
            # Open a session owned by this test instead of reusing the (already
            # closed) ``session`` name left over from the module-level setup.
            with Session(engine) as session:
                first = [
                    dict(name=u.name)
                    for u in session.query(User).filter(User.name == 'super sonic')
                ]
                second = [
                    dict(name=u.name)
                    for u in session.query(User).filter(User.name == 'super sonic')
                ]
        self.assertListEqual([dict(name='mock1'), dict(name='mock2')], first)
        self.assertListEqual([dict(name='mock3'), dict(name='mock4')], second)
运行 `pytest {YOUR_SCRIPT_NAME}.py`:
=========================================================== test session starts ============================================================
platform linux -- Python 3.10.6, pytest-7.4.0, pluggy-1.3.0
rootdir: /home/{PATH_TO_PROJECT}
plugins: anyio-3.7.1
collected 2 items
{YOUR_SCRIPT_NAME}.py .. [100%]
============================================================ 2 passed in 0.19s =============================================================