SQLAlchemy replacement_traverse does not modify the query


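My understanding is that the wrapper function passed to replacement_traverse should replace that element (in this case FooModel) with the aliased, tablesampled object. However, the returned query is the same as the one passed in, even though the output shows one hit on the if branch and two hits on the elif branch. I expect that after the replacement, the raw SQL of my query would change from using foo to foo_1 (the sampled object).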
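The behaviour the question expects can be pictured with a toy, stdlib-only model of a replacement traversal. Node, replace_traverse and sample_foo are hypothetical names; the sketch mirrors the callback contract as I understand it from sqlalchemy.sql.visitors (return a substitute to use it, or None to keep descending) and returns a new tree instead of editing the old one. It is not SQLAlchemy's implementation, which has extra logic for FROM clauses and their columns.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Node:
    """Hypothetical expression-tree node (not a SQLAlchemy class)."""
    name: str
    children: List["Node"] = field(default_factory=list)


def replace_traverse(node: Node, replace: Callable[[Node], Optional[Node]]) -> Node:
    """Toy model of the replacement-traversal contract.

    If replace() returns a node, it is used instead of the original and is
    not traversed further; if it returns None, the node is copied and its
    children are visited. The input tree is never mutated.
    """
    substitute = replace(node)
    if substitute is not None:
        return substitute
    return Node(node.name, [replace_traverse(child, replace) for child in node.children])


def sample_foo(node: Node) -> Optional[Node]:
    # swap the "foo" table node; leave everything else alone
    return Node("foo_1 TABLESAMPLE") if node.name == "foo" else None


tree = Node("select", [Node("count", [Node("foo")]), Node("where", [Node("bar")])])
new_tree = replace_traverse(tree, sample_foo)
print(new_tree.children[0].children[0].name)  # foo_1 TABLESAMPLE
print(tree.children[0].children[0].name)      # foo (original untouched)
```

The point of the model is that the caller must use the returned object; the original query object is left as it was.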

def test():
    """Try to swap FooModel for a TABLESAMPLE'd alias using replacement_traverse.
    https://www.mail-archive.com/[email protected]/msg25258.html
    """
    import time
    import sqlalchemy as sa
    from sqlalchemy import Table, tablesample
    from sqlalchemy.orm import aliased
    from sqlalchemy.sql.visitors import replacement_traverse
    global session
    
    def wrap():
        _foo_sampled = aliased(foomodel.FooModel, tablesample(foomodel.FooModel, 50))
        print(_foo_sampled)
        def replace(element, **kw):
            """https://www.mail-archive.com/[email protected]/msg25258.html"""
            if isinstance(element, Table) and element.name == foomodel.FooModel.__tablename__:
                print(f"replacing if-branch {element} {type(element)} {_foo_sampled} {type(_foo_sampled)}")
                return _foo_sampled
            elif foomodel.FooModel.__table__.c.contains_column(element):  # replace columns in the table
                print("replacing elif-branch")
                return _foo_sampled.__table__.c[element.key]
            else:
                print(type(element))
        return replace
    wrapper = wrap()
    
    q = (
        sa.select(sa.func.count(foomodel.FooModel.id), barmodel.BarModel.name)
        .join(
            foomodel.foo_baz_association_table,
            foomodel.foo_baz_association_table.c.foo_id == foomodel.FooModel.id,
        )
        .join(barmodel.BazModel, barmodel.BazModel.id == foomodel.foo_baz_association_table.c.baz_id)
        .join(barmodel.BarModel, barmodel.BazModel.bar_id == barmodel.BarModel.id)
        .where(barmodel.BarModel.name != sa.null())
        .group_by(barmodel.BarModel.name)
    )
    
    print(q)
    
    tic = time.time()
    results0 = session.execute(q).all()
    toc0 = time.time() - tic

    new_q = replacement_traverse(q, {}, wrapper)
    print("\n"*3)
    print(new_q)
    print("\n"*3)
    
    tic = time.time()
    results2 = session.execute(new_q).all()
    toc2 = time.time() - tic
    print(f"{toc0=} {toc2=}")
    print(results0)
    print(results2)  # expect roughly half the counts with a 50 percent tablesample

Wanted

SELECT count(foo_1.id) AS count_1, bar.name 
FROM foo AS foo_1 TABLESAMPLE system(:system_1) JOIN 

Actual

SELECT count(foo.id) AS count_1, bar.name 
FROM foo  JOIN
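A quick type check that is consistent with the answer below: aliased() hands back an ORM AliasedClass, whereas tablesample() returns a Core TableSample, which is a FromClause. The sketch assumes SQLAlchemy 1.4 or 2.x, and FooModel here is a hypothetical minimal stand-in for the question's model. A plausible explanation (my reading, not something the SQLAlchemy docs state) is that when the replacement for a Table is not a FromClause, the Select copying logic has nothing it can re-point the columns to, so the query comes out unchanged.

```python
import sqlalchemy as sa
from sqlalchemy.orm import aliased, declarative_base
from sqlalchemy.sql.selectable import FromClause

Base = declarative_base()


class FooModel(Base):
    # hypothetical stand-in for foomodel.FooModel from the question
    __tablename__ = "foo"
    id = sa.Column(sa.Integer, primary_key=True)


sampled = sa.tablesample(FooModel.__table__, 50)  # Core construct
orm_alias = aliased(FooModel, sampled)            # ORM-level proxy around it

# compare what each object is from the Core expression system's point of view
print(type(sampled).__name__, isinstance(sampled, FromClause))
print(type(orm_alias).__name__, isinstance(orm_alias, FromClause))
```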
1 Answer

Not sure whether this is a bug in SQLAlchemy, but the solution is to change

_foo_sampled = aliased(foomodel.FooModel, tablesample(foomodel.FooModel, 50))

so that it does not use aliased and instead passes only the result of tablesample:
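_foo_sampled = tablesample(foomodel.FooModel, 50)  # a Core TableSample (a FROM clause), not wrapped in aliased()
def sample_host_device(element, **kw):
    # swap the foo Table for the sampled FROM clause; returning None means "no replacement here"
    if isinstance(element, Table) and element.name == foomodel.FooModel.__tablename__:
        return _foo_sampled
    return None

new_q = replacement_traverse(q, {}, sample_host_device)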
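A self-contained sketch of the same fix using only Core, assuming SQLAlchemy 1.4 or 2.x. The foo table, its columns and sample_foo are hypothetical stand-ins for the question's models; no database is needed, because printing a select compiles it with the default string dialect.

```python
import sqlalchemy as sa
from sqlalchemy.sql.visitors import replacement_traverse

# hypothetical stand-in for foomodel.FooModel.__table__
metadata = sa.MetaData()
foo = sa.Table(
    "foo",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.String),
)

# plain TableSample (a Core FROM clause), deliberately not wrapped in aliased()
foo_sampled = sa.tablesample(foo, 50)


def sample_foo(element, **kw):
    """Replace the foo Table with the sampled FROM; None keeps traversing."""
    if isinstance(element, sa.Table) and element.name == foo.name:
        return foo_sampled
    return None


q = sa.select(sa.func.count(foo.c.id)).where(foo.c.name != sa.null())
new_q = replacement_traverse(q, {}, sample_foo)  # returns a new statement; q is untouched

print(q)
print(new_q)
```

If the rewrite takes effect, the second statement should select from foo AS foo_1 TABLESAMPLE system(:system_1) and count foo_1.id instead of foo.id. Actually executing TABLESAMPLE needs a backend that supports it, such as PostgreSQL.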
© www.soinside.com 2019 - 2024. All rights reserved.