tqdm不再继续前进,而是重新开始

问题描述 投票:0回答:1

我有以下工作代码。 我想在新行上打印所有第 5000 个项目的每一项损失 (% 5000) 但进度条应该继续 它的方式一劳永逸(应该只打印一次)并显示总进度。我该如何修改代码?

import torch
from math import tanh,cos
from tqdm import tqdm
from time import sleep

batch, dim_in, dim_h, dim_out = 1, 100, 10, 1

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

SGD_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.Tanh(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

rate_learning = 0.01

optim = torch.optim.SGD(SGD_model.parameters(), lr=rate_learning, momentum=0.01)
    
for values in tqdm(range(1000)):
    pred_y = SGD_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    if values % 100 == 0:
        print(values, loss.item())  
    optim.zero_grad()
    loss.backward()
    optim.step()

三个进度条而不是一个:

chatGPT 4 为我提供了这段代码,但仍然有 4 个进度条:

import torch
from math import tanh
from time import sleep

batch, dim_in, dim_h, dim_out = 32, 10, 5, 1

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

SGD_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.Tanh(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

rate_learning = 0.0001

optim = torch.optim.SGD(SGD_model.parameters(), lr=rate_learning, momentum=0.4)

# Define the training loop with a range
epochs = 1000
for epoch in range(epochs):
    progress = (epoch + 1) / epochs
    bar_length = 50
    block = int(round(bar_length * progress))
    text = f"Epoch {epoch+1}/{epochs} [{'#' * block + '-' * (bar_length - block)}] {100 * progress:.2f}%"
    pred_y = SGD_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    if epoch % 100 == 0:
        text += f', Loss: {loss.item()}'
    print(text, end='\r')
    optim.zero_grad()
    loss.backward()
    optim.step()

print("\nWeights:", SGD_model[0].weight)
print("Bias:", SGD_model[0].bias)
print("Bias:", SGD_model[2].bias) 
python deep-learning printing epoch tqdm
1个回答
0
投票

而不是

print( ... )
,更喜欢
tqdm.write( ... )

另外,考虑用 记录器。 那么您可以信赖 tqdm 上下文管理器。 来自文档:

    with logging_redirect_tqdm():
        for i in trange(9):
            if i == 4:
                LOG.info("console logging redirected to `tqdm.write()`")
© www.soinside.com 2019 - 2024. All rights reserved.