如何使用 torch.hub.load 加载本地模型?

问题描述 投票:0回答:3

我需要避免从网上下载模型(由于安装机器的限制)。

这行得通,但它是从互联网上下载模型

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

我已经将

.pth
文件和
hubconf.py
文件放在/tmp/文件夹中,并将我的代码更改为

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')

但令我惊讶的是,它仍然从网上下载模型。我究竟做错了什么?如何在本地加载模型?

只是为了给你更多的细节,我在运行时有一个只读卷的 Docker 容器中完成所有这些,所以这就是下载新文件失败的原因。

python machine-learning pytorch torch torchvision
3个回答
7
投票

您可以采用两种方法在没有互联网连接的机器上获得可交付模型。

  1. 在普通机器上加载预训练模型的DeepLab,使用JIT编译器将其导出为图,并放入机器中。脚本很容易理解:

     # To export
     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
     traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
     traced_graph.save('DeepLab.pth')
    
     # To load
     model = torch.jit.load('DeepLab.pth').eval().to(device)
    

    在这种情况下,权重和网络结构被保存为计算图,因此您不需要任何额外的文件。

  2. 看看 torchvision 的 GitHub 存储库.

    DeepLabV3 有一个下载 URL,具有 Resnet101 骨干权重。

    您可以下载这些权重一次,然后使用带有 pretrained=False 标志的 torchvision 的 deeplab 并手动加载权重。

     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
     model.load_state_dict(torch.load('downloaded weights path'))
    

    考虑一下,在状态字典中可能有一个['state_dict']或一些类似的父键,您可以在其中使用:

     model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
    

4
投票
model_name='best.pt'
model = torch.hub.load(os.getcwd(), 'custom', source='local', path = model_name, force_reload = True)

这对我有用。默认源是 github。


0
投票

这锅适合我:

    # model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model_path = '~/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth'
    model = deeplabv3_resnet101(pretrained=True)
    model.load_state_dict(torch.load(model_path))
    model.eval()
© www.soinside.com 2019 - 2024. All rights reserved.