我需要避免从网上下载模型(由于安装机器的限制)。
这行得通,但它是从互联网上下载模型
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
我已经将
.pth
文件和hubconf.py
文件放在/tmp/文件夹中,并将我的代码更改为
model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')
但令我惊讶的是,它仍然从网上下载模型。我究竟做错了什么?如何在本地加载模型?
只是为了给你更多的细节,我在运行时有一个只读卷的 Docker 容器中完成所有这些,所以这就是下载新文件失败的原因。
您可以采用两种方法在没有互联网连接的机器上获得可交付模型。
在普通机器上加载预训练模型的DeepLab,使用JIT编译器将其导出为图,并放入机器中。脚本很容易理解:
# To export
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
traced_graph.save('DeepLab.pth')
# To load
model = torch.jit.load('DeepLab.pth').eval().to(device)
在这种情况下,权重和网络结构被保存为计算图,因此您不需要任何额外的文件。
DeepLabV3 有一个下载 URL,具有 Resnet101 骨干权重。
您可以下载这些权重一次,然后使用带有 pretrained=False 标志的 torchvision 的 deeplab 并手动加载权重。
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
model.load_state_dict(torch.load('downloaded weights path'))
考虑一下,在状态字典中可能有一个['state_dict']或一些类似的父键,您可以在其中使用:
model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
model_name='best.pt'
model = torch.hub.load(os.getcwd(), 'custom', source='local', path = model_name, force_reload = True)
这对我有用。默认源是 github。
这锅适合我:
# model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
model_path = '~/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth'
model = deeplabv3_resnet101(pretrained=True)
model.load_state_dict(torch.load(model_path))
model.eval()