意外的打印输出会干扰 PyTorch 训练运行中的 tqdm 进度条

问题描述 投票:0回答:1

我试图了解使用

tqdm
的进度条是如何准确工作的。我有一些代码如下所示:

import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()

# them within the main function I have placed the train function that exists in the `engine.py` file
def main():

      results = engine.train(model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=5,
        device=device)

并且

engine.train()
函数包含以下代码
for epoch in tqdm(range(epochs)):
然后,对每个批次进行训练以可视化训练进度。每次 tqdm 运行每个步骤时,它还会打印以下语句:

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

最后,我的问题是为什么会发生这种情况。主函数如何访问这些全局语句以及如何避免在每个循环中打印所有内容?

python pytorch tqdm
1个回答
1
投票

您注意到的实际上与

tqdm
无关,而是与 PyTorch 的内部工作原理(特别是
DataLoader
num_workers
属性)和 Python 的底层
multiprocessing
框架有关。这是一个应该重现您的问题的最小工作示例:

from contextlib import suppress
from multiprocessing import set_start_method
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
print("torch version:", torch.__version__)

class DummyData(Dataset):
    def __len__(self): return 256
    def __getitem__(self, i): return i

def main():
    for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
        pass  # Do something
    
if __name__ == "__main__":
    # Enforce "spawn" method (e.g. on Linux) for subprocess creation to
    # reproduce problem (suppress error for reruns in same interpreter)
    with suppress(RuntimeError): set_start_method("spawn")
    main()

如果运行这段代码,您应该会看到 PyTorch 版本号被打印了 4 次,弄乱了您的

tqdm
进度条。这个数字与
num_workers
相同并非巧合(您可以通过更改此数字轻松检查)。

发生的情况如下:

  • 如果
    num_workers
    > 0,则为工作人员启动子流程。
  • 在 Windows 和 macOS 上,这些子进程默认使用“spawn”方法启动(在 Linux 上,可以强制执行此方法来重现您的观察结果,我已使用
    set_start_method()
    完成了这一操作)。
  • “spawn”方法将为每个子进程启动一次主脚本,执行所有不受
    if __name__ == "__main__":
    块保护的行。这包括您在脚本顶部的
    print()
    调用。

该行为以及潜在的缓解措施已记录在此处。我想,对你有用的一个是:

将大部分主脚本代码包装在

if __name__ == '__main__':
块中,以确保它不会再次运行

所以,要么

  1. print()
    调用移至
    if __name__ == '__main__':
    块的开头,
  2. print()
    调用移至
    main()
    函数的开头,或者
  3. 删除
    print()
    呼叫。

或者,但这可能不是您想要的,您可以设置

num_workers=0
,这将完全禁用
multiprocessing
的底层使用(但这样您也会失去并行化的好处)。请注意,您可能还应该将其他函数调用(例如
load_data()
)移至
if __name__ == '__main__':
块或
main()
函数中,以避免多次意外执行。

© www.soinside.com 2019 - 2024. All rights reserved.