对于streamlit的st.cache_resource中不可散列的项目来说,良好的“hash_func” - 主要是数据帧?

问题描述 投票:0回答:1

Streamlit 的 st.cache_resource 装饰器是加速 Streamlit 应用程序的关键。

在如下函数中:

@st.cache_resource
def slow_function(inputA, input_b): 
   ...
   return something

它可以通过记忆来加速代码。

但是,这依赖于所有输入都是“可哈希的”。如果输入本身没有实现

__hash__
dunder 方法,那么用户可以提供哈希函数。


class myCustomType:
   ...

@st.cache_resource(hash_funcs={myCustomType: lambda obj: ___})
def slow_function(input_a:int, input_b:myCustomType): 
   ...
   return something

我的问题是可以使用什么“通用”哈希函数,特别是对于 pandas 或 Polars 数据帧等输入。

我已经尝试过:

hash_funcs={pl.DataFrame: lambda obj: id(obj)} # Not stable across page re-execution
hash_funcs={pl.DataFrame: lambda obj: f'{obj.shape} {obj.schema}} # Not confident it's unique
hash_funcs={pl.DataFrame: lambda obj: hash_all_cells(obj) } # too slow

这里还有其他可以应用的解决方案吗?

python hash streamlit python-polars memoization
1个回答
0
投票

数据帧不应该是输入。输入将是日期、文件或来自用户界面的任何内容。我猜您想缓存为回答 Q1 b/c 加载的数据帧,相同的数据可能对 Q2 有用。为此,您需要使用

globals()
创建自己的数据帧缓存器。

它可能看起来像这样:

def load_cache_df(filepath:str)->pl.DataFrame:
    if "df_cache" not in globals():
        globals['df_cache']={}
    if filepath not in globals()['df_cache']:
        globals()['df_cache'][filepath]=pl.read_parquet(filepath)
    return globals()['df_cache'][filepath]

您可能需要另一个函数来根据 UI 输入推断文件路径,这很好,并且可能会被 Streamlit 缓存,因为这些文件路径可能是可哈希的。

那么你就会有类似的东西

@st.cache_resource
def derive_file_path(inputA, input_b):
   return f"{inputA}/{input_b} # or whatever it is



@st.cache_resource
def slow_function(inputA, input_b): 
   df = load_cache_df(derive_file_path(inputA, input_b))
   return df.group_by(inputA).agg(pl.col('cool_data').sum())
© www.soinside.com 2019 - 2024. All rights reserved.