我有如下数据库模型:
class Category(Base):
    """A category that groups filters (table ``category``).

    ``Filter.category`` points here; its ``backref='filters'`` adds
    ``Category.filters`` listing the filters of this category.
    """
    __tablename__ = 'category'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)  # presumably a display name — TODO confirm
class Filter(Base):
    """A filter type such as "brand" or "pixel", belonging to a category (table ``filter``)."""
    __tablename__ = 'filter'
    id = Column(Integer, primary_key=True, nullable=False)
    # Column type is inferred from the referenced category.id.
    category_id = Column(ForeignKey("category.id"), nullable=False)
    title = Column(Text, nullable=False)  # presumably the human-readable label — TODO confirm
    # Machine name matched in queries, e.g. Filter.transcription == "brand".
    transcription = Column(Text, nullable=False)
    # Many-to-one to Category; backref adds Category.filters.
    category = relationship('Category', backref='filters')
class FilterParameter(Base):
    """One allowed value of a filter, e.g. "acer" for "brand" (table ``fparam``)."""
    __tablename__ = 'fparam'
    id = Column(Integer, primary_key=True, nullable=False)
    value = Column(Text, nullable=False)  # presumably the display value — TODO confirm
    # Foreign key to filter.id (named feature_id, not filter_id).
    feature_id = Column(ForeignKey("filter.id"), nullable=False)
    # Machine value matched in queries, e.g. transcription.in_(["acer", "asus"]).
    transcription = Column(Text, nullable=False)
    # Many-to-one to Filter; backref adds Filter.fparams.
    filter = relationship('Filter', backref="fparams")
class Product(Base):
    """A product to be narrowed down by filter parameters (table ``product``).

    Its parameter values are linked through ``ProductParameter`` rows; the
    ``pparams`` attribute is added by that model's backref.
    """
    __tablename__ = 'product'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)
    pic = Column(Text)  # the only nullable column; presumably an image path/URL — TODO confirm
    price = Column(Float, nullable=False)
    availability = Column(Boolean, nullable=False)
    keywords = Column(Text, nullable=False)  # free text; its format is not shown here
class ProductParameter(Base):
    """Association row: a product has a given filter-parameter value (table ``pparam``).

    Only the ``product`` relationship is mapped; reaching the
    ``FilterParameter`` requires an explicit join on ``param_id``.
    """
    __tablename__ = 'pparam'
    id = Column(Integer, primary_key=True, nullable=False)
    product_id = Column(ForeignKey("product.id"), nullable=False)
    param_id = Column(ForeignKey("fparam.id"), nullable=False)
    # Many-to-one to Product; backref adds Product.pparams.
    product = relationship('Product', backref="pparams")
目标是按多个参数(用 `and_` 组合条件)过滤产品列表。一个产品有一组不同的参数(一对多),这些参数通过 ProductParameter 关联到 FilterParameter 表中。
我认为我已经接近目标了,因为按参数类型之一进行过滤是有效的:
# Join each product to its parameter values and to their filter types.
res = (
    cursor.session.query(Product)
    .join(ProductParameter, Product.id == ProductParameter.product_id)
    .join(FilterParameter, FilterParameter.id == ProductParameter.param_id)
    .join(Filter, Filter.id == FilterParameter.feature_id)
)
# A single filter type: brand is "acer" OR "asus".
res = res.filter(
    FilterParameter.transcription.in_("acer,asus".split(',')),
    Filter.transcription == "brand",
).all()
它有效。
但是如果你需要按几种类型的参数进行过滤:
# Does NOT work: assuming `res` is the single-join query built above, both
# .filter() calls constrain the SAME joined FilterParameter/Filter row, so
# Filter.transcription would have to equal "brand" and "pixel" at once.
# No row can satisfy that, so the result is empty. Each filter type needs
# its own (aliased) join chain or its own subquery.
res = res.filter(and_(FilterParameter.transcription.in_("acer,asus".split(',')),
                      Filter.transcription == "brand"))\
    .filter(and_(FilterParameter.transcription.in_("1920x1080".split(',')),
                 Filter.transcription == "pixel")).all()
这不起作用。
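这个问题比看起来要复杂。上面的写法之所以不起作用,是因为两组条件都作用在同一组连接出来的 FilterParameter/Filter 行上,而同一行的 Filter.transcription 不可能同时等于 "brand" 和 "pixel",所以结果总是为空。可以改用以下几种方式:为每个过滤器单独做一组连接、为每个过滤器使用一个子查询,或者(如果使用 PostgreSQL)使用数组列。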
设置模型并填充数据库
import os
from sqlalchemy import (
create_engine,
Column,
Integer,
ForeignKey,
Text,
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.sql import (
and_,
)
from sqlalchemy.orm import (
declarative_base,
Session,
relationship,
aliased,
)
def get_engine(env, echo=True):
    """Create a PostgreSQL (psycopg2) engine from ``DB_*`` settings in *env*.

    Reads ``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT`` and
    ``DB_NAME`` from the mapping (e.g. ``os.environ``). The URL is built with
    ``URL.create`` rather than string formatting, so credentials that contain
    URL-special characters such as ``@``, ``:`` or ``/`` are escaped
    correctly instead of producing a malformed URL.

    :param env: mapping that provides the ``DB_*`` keys.
    :param echo: passed through to ``create_engine``. The default ``True``
        keeps the previous hard-coded behaviour of logging emitted SQL.
    :returns: an SQLAlchemy ``Engine``.
    :raises KeyError: if one of the ``DB_*`` keys is missing.
    :raises ValueError: if ``DB_PORT`` is not an integer.
    """
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+psycopg2",
        username=env['DB_USER'],
        password=env['DB_PASSWORD'],
        host=env['DB_HOST'],
        # URL.create expects an int port, whereas environment values are strings.
        port=int(env['DB_PORT']),
        database=env['DB_NAME'],
    )
    return create_engine(url, echo=echo)
# Declarative base shared by the models below; populate() creates their
# tables via Base.metadata.create_all().
Base = declarative_base()
class Filter(Base):
    """A filter type such as "brand" or "pixel" (table ``filter``)."""
    __tablename__ = 'filter'
    id = Column(Integer, primary_key=True, nullable=False)
    # Machine name matched by the query helpers, e.g. "brand".
    transcription = Column(Text, nullable=False)
class FilterParameter(Base):
    """One allowed value of a filter, e.g. "acer" for "brand" (table ``fparam``)."""
    __tablename__ = 'fparam'
    id = Column(Integer, primary_key=True, nullable=False)
    # Foreign key to filter.id (named feature_id, not filter_id).
    feature_id = Column(ForeignKey("filter.id"), nullable=False)
    # Machine value matched by the query helpers, e.g. "acer" or "1920x1080".
    transcription = Column(Text, nullable=False)
    # Many-to-one to Filter; backref adds Filter.fparams.
    filter = relationship('Filter', backref="fparams")
class Product(Base):
    """A filterable product (table ``product``)."""
    __tablename__ = 'product'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)
    # PostgreSQL text[] of denormalised "<filter>--<value>" tags such as
    # "brand--acer"; it must be kept in sync with the pparam rows by hand.
    tags = Column(ARRAY(Text))
class ProductParameter(Base):
    """Association row: a product has a given filter-parameter value (table ``pparam``).

    Only the ``product`` relationship is mapped; queries reach the
    ``FilterParameter`` with an explicit join on ``param_id``.
    """
    __tablename__ = 'pparam'
    id = Column(Integer, primary_key=True, nullable=False)
    product_id = Column(ForeignKey("product.id"), nullable=False)
    param_id = Column(ForeignKey("fparam.id"), nullable=False)
    # Many-to-one to Product; backref adds Product.pparams.
    product = relationship('Product', backref="pparams")
def populate(engine):
    """Create the schema and insert one sample product with its parameters.

    Two filters are inserted: "brand" with the values "acer" and "asus", and
    "pixel" with "1920x1080". A single "TV" product references "acer" and
    "1920x1080", both through ``pparam`` rows and through its ``tags`` array.

    NOTE(review): rows are inserted unconditionally, so running this against a
    database that already holds the sample data duplicates it, and the counts
    checked by ``query`` will no longer match.
    """
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        def add_filter(name, values):
            # Add the filter first, then its values, preserving insert order.
            parent = Filter(transcription=name)
            session.add(parent)
            children = [FilterParameter(filter=parent, transcription=value) for value in values]
            session.add_all(children)
            return children

        brand_values = add_filter('brand', ['acer', 'asus'])
        pixel_values = add_filter('pixel', ['1920x1080'])
        # Flush so the FilterParameter rows get primary keys, which the
        # ProductParameter rows below reference by id.
        session.flush()

        tv = Product(title="TV", tags=['brand--acer', 'pixel--1920x1080'])
        session.add(tv)
        acer, full_hd = brand_values[0], pixel_values[0]
        session.add_all([
            ProductParameter(product=tv, param_id=acer.id),
            ProductParameter(product=tv, param_id=full_hd.id),
        ])
        session.commit()
此方案为每个过滤器分别连接一组(使用别名的)ProductParameter、FilterParameter 和 Filter。可以通过预先查出过滤器(例如 `"brand"`)的 id 来进一步优化,从而省去对 Filter 表的连接。
def filter_products_with_joins(session, filters):
    """Return the products that satisfy every ``(filter_name, values)`` pair.

    For each pair, a fresh aliased chain pparam -> fparam -> filter is
    inner-joined onto ``product``. The chain is restricted to the filter's
    name and its allowed values, so a product survives only if it has a
    matching parameter for *every* filter (AND). The comma-separated values
    of one filter are alternatives (OR via ``IN``).

    :param session: an open SQLAlchemy ``Session``.
    :param filters: iterable of ``(filter_name, "value1,value2,...")`` pairs;
        an empty iterable applies no restriction.
    :returns: list of ``Product`` instances.
    """
    query = session.query(Product)
    for name, csv_values in filters:
        # Separate aliases per filter so each join chain binds its own rows.
        link = aliased(ProductParameter)
        param = aliased(FilterParameter)
        feature = aliased(Filter)
        allowed = csv_values.split(',')
        query = (
            query
            .join(link, Product.id == link.product_id)
            .join(param, and_(param.id == link.param_id, param.transcription.in_(allowed)))
            .join(feature, and_(feature.id == param.feature_id, feature.transcription == name))
        )
    return query.all()
这与连接方案类似,但可能更快,也更灵活、更易读。我们不是把多个连接串在一起,而是为每个过滤器构造一个子查询,再用 AND 把它们组合起来。也就是说:要找的产品,其 id 既在"属于这些品牌的产品 id"列表中,又在"具有这些像素值的产品 id"列表中,依此类推。
def filter_products_with_subqueries(session, filters):
    """Return the products that satisfy every ``(filter_name, values)`` pair.

    One ``SELECT pparam.product_id ...`` is built per pair, restricted to the
    filter's name and its comma-separated allowed values. The outer query
    requires ``Product.id`` to be ``IN`` every one of those selects: AND
    across filters, OR across the values of a single filter.

    The ``select()`` is passed to ``in_()`` directly. Passing a
    ``.subquery()`` there instead makes SQLAlchemy emit a "Coercing Subquery
    object into a select() for use in IN()" warning before unwrapping it
    anyway.

    :param session: an open SQLAlchemy ``Session``.
    :param filters: iterable of ``(filter_name, "value1,value2,...")`` pairs;
        an empty iterable applies no restriction.
    :returns: list of ``Product`` instances.
    """
    from sqlalchemy import select

    conditions = []
    for filter_name, filter_params in filters:
        # The product table is not joined here: the outer Product.id IN (...)
        # already restricts the result to existing products.
        matching_ids = (
            select(ProductParameter.product_id)
            .join(
                FilterParameter,
                and_(
                    FilterParameter.id == ProductParameter.param_id,
                    FilterParameter.transcription.in_(filter_params.split(',')),
                ),
            )
            .join(
                Filter,
                and_(
                    Filter.id == FilterParameter.feature_id,
                    Filter.transcription == filter_name,
                ),
            )
        )
        conditions.append(Product.id.in_(matching_ids))
    return session.query(Product).filter(*conditions).all()
这种方式查询更简单,但更难维护(标签与 pparam 表中的数据是重复的,需要手动保持同步)。这里把过滤器和过滤器参数压缩成标签(形如 `brand--acer` 的"过滤器--参数值"组合)。先为产品设置标签,再用标签列表进行查询。也就是说:要找的产品,其标签中至少包含一个给定的品牌标签,并且至少包含一个给定的像素标签,依此类推。
def filter_products_with_tags(session, filters):
    """Return the products whose ``tags`` match every ``(filter_name, values)`` pair.

    Each pair is expanded into ``"<filter_name>--<value>"`` tags. A product
    matches the pair when its ``tags`` array shares at least one element with
    that list (``ARRAY.overlap``). All pairs must match.

    :param session: an open SQLAlchemy ``Session``.
    :param filters: iterable of ``(filter_name, "value1,value2,...")`` pairs;
        an empty iterable applies no restriction.
    :returns: list of ``Product`` instances.
    """
    conditions = []
    for filter_name, filter_params in filters:
        wanted = [f'{filter_name}--{value}' for value in filter_params.split(',')]
        conditions.append(Product.tags.overlap(wanted))
    return session.query(Product).filter(*conditions).all()
测试和主入口。
def query(engine):
    """Check every filtering strategy against the data inserted by ``populate``.

    Each strategy is called with the same filter combinations and must return
    the expected number of products. A failing check reports which strategy,
    which filters, and the actual vs expected count.

    :param engine: engine connected to a database seeded by ``populate``.
    :raises AssertionError: if a strategy returns an unexpected count.
    """
    # (filters, expected number of matching products)
    cases = [
        ([('brand', 'acer,asus')], 1),
        ([('brand', 'acer,asus'), ('pixel', '1920x1080')], 1),
        # A value that matches nothing ("sony") must not hide the ones that do.
        ([('brand', 'acer,asus,sony'), ('pixel', '1920x1080')], 1),
        # Every filter must match: no product has the value 480x360.
        ([('brand', 'acer,asus,sony'), ('pixel', '480x360')], 0),
        ([('brand', 'acer,asus,sony'), ('pixel', '1920x1080,480x360')], 1),
    ]
    with Session(engine) as session:
        for func in (filter_products_with_joins, filter_products_with_subqueries, filter_products_with_tags):
            for filters, expected in cases:
                found = len(func(session, filters))
                assert found == expected, (
                    f"{func.__name__}({filters!r}) returned {found} products, expected {expected}"
                )
def main():
    """Connect using the DB_* environment variables, seed sample data, run the checks."""
    engine = get_engine(env=os.environ)
    for step in (populate, query):
        step(engine)
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()