Streamlit 在cache_data 函数中不显示进度条

问题描述 投票:0回答:1

使用streamlit,我将重函数设置为

cache_data
类型以避免重新计算。由于该功能很耗时,我还想在其中创建一个进度条。但是,我发现它无法与指定的
cache_data
一起使用。

以下是MWE

import streamlit as st
from time import sleep

@st.cache_data(show_spinner = False)
def showProgressBar():
    cur = 0
    total = 100
    my_bar = st.progress(cur / total, text = "%d / %d" % (cur, total))
    while cur < total:
        sleep(0.05)
        cur = cur + 1
        my_bar.progress(cur / total, text = "%d / %d" % (cur, total))
    my_bar.empty()

### Main ###
st.set_page_config(page_title="Test Progress In Cache Function", page_icon=":bar_chart:",layout="wide")
st.title(" :bar_chart: Test")
showProgressBar()
st.text('Test Finish')

结果,这样的代码永远不会显示栏。但如果我注释掉

@st.cache_data
行,进度条就会按预期工作。

This thread 中提到了关于

progress
st.cache
的类似问题,解决方法似乎与
suppress_st_warning = True
有关,但是随着
st.cache
的贬值,此参数似乎不再适用于
st.cache_data

有人可以在这里提供一些帮助吗?

caching progress-bar streamlit
1个回答
0
投票

如果您的繁重计算是

sleep(0.05)
,这里有一个方法。这段代码从进度条处理中提取出繁重的计算,确保
cache_data
装饰器只能用于非 Streamlit 操作:

import streamlit as st
from time import sleep
from random import randint

@st.cache_data(show_spinner = False)
def do_heavy_calc(n):
    print(f"First time seeing {n}")
    sleep(0.05)


def showProgressBar():
    cur = 0
    total = 100
    my_bar = st.progress(cur / total, text = "%d / %d" % (cur, total))
    while cur < total:
        n = randint(1, 100)
        # Add a random number since I suppose you don't want to run
        # the exact same function everytime(?)
        do_heavy_calc(n)
        cur = cur + 1
        my_bar.progress(cur / total, text = "%d / %d" % (cur, total))
    my_bar.empty()

### Main ###
st.set_page_config(page_title="Test Progress In Cache Function", page_icon=":bar_chart:",layout="wide")
st.title(" :bar_chart: Test")
showProgressBar()
st.text('Test Finish')

但是,如果您想跟踪仅调用一次的一个长函数,则可以将其转换为生成器,生成

do_heavy_calc
函数当前所处的当前步骤。不过我不推荐这样做(感觉有点老套),但它在我的测试中似乎工作得很好。

有两点需要注意:

  • cache_data
    cache_resource
    取代,因为
    cache_data
    提高
    streamlit.runtime.caching.cache_errors.UnserializableReturnValueError
  • 我在 2 个输入参数前面加了一个下划线 (
    _
    ):这是为了确保它们不会用于缓存目的。请参阅此 streamlit 文档页面的“排除输入参数”部分。
import streamlit as st
from time import sleep

@st.cache_resource(show_spinner = False)
def do_heavy_calc(_cur, _total):
    while _cur < _total:
        # yield where you are so that the progress bar on the outside can
        # keep track
        yield _cur
        sleep(0.05)
        _cur = _cur + 1

def showProgressBar():
    min_ = 0
    max_ = 100
    my_bar = st.progress(min_ / max_, text = "%d / %d" % (min_, max_))
    for c in do_heavy_calc(0, max_):
        my_bar.progress(c / max_, text = "%d / %d" % (c, max_))
    my_bar.empty()

### Main ###
st.set_page_config(page_title="Test Progress In Cache Function", page_icon=":bar_chart:",layout="wide")
st.title(" :bar_chart: Test")
showProgressBar()
st.text('Test Finish')
© www.soinside.com 2019 - 2024. All rights reserved.