Streamlit 的 st.cache_resource 装饰器是加速 Streamlit 应用程序的关键。
在如下函数中:
@st.cache_resource
def slow_function(inputA, input_b):
...
return something
它可以通过记忆来加速代码。
但是,这依赖于所有输入都是“可哈希的”。如果输入本身没有实现
__hash__
dunder 方法,那么用户可以提供哈希函数。
class myCustomType:
...
@st.cache_resource(hash_funcs={myCustomType: lambda obj: ___})
def slow_function(input_a:int, input_b:myCustomType):
...
return something
我的问题是可以使用什么“通用”哈希函数,特别是对于 pandas 或 Polars 数据帧等输入。
我已经尝试过:
hash_funcs={pl.DataFrame: lambda obj: id(obj)} # Not stable across page re-execution
hash_funcs={pl.DataFrame: lambda obj: f'{obj.shape} {obj.schema}} # Not confident it's unique
hash_funcs={pl.DataFrame: lambda obj: hash_all_cells(obj) } # too slow
这里还有其他可以应用的解决方案吗?
数据帧不应该是输入。输入将是日期、文件或来自用户界面的任何内容。我猜您想缓存为回答 Q1 b/c 加载的数据帧,相同的数据可能对 Q2 有用。为此,您需要使用
globals()
创建自己的数据帧缓存器。
它可能看起来像这样:
def load_cache_df(filepath:str)->pl.DataFrame:
if "df_cache" not in globals():
globals['df_cache']={}
if filepath not in globals()['df_cache']:
globals()['df_cache'][filepath]=pl.read_parquet(filepath)
return globals()['df_cache'][filepath]
您可能需要另一个函数来根据 UI 输入推断文件路径,这很好,并且可能会被 Streamlit 缓存,因为这些文件路径可能是可哈希的。
那么你就会有类似的东西
@st.cache_resource
def derive_file_path(inputA, input_b):
return f"{inputA}/{input_b} # or whatever it is
@st.cache_resource
def slow_function(inputA, input_b):
df = load_cache_df(derive_file_path(inputA, input_b))
return df.group_by(inputA).agg(pl.col('cool_data').sum())