How do I load a checkpoint from a model trained and saved with nn.DataParallel into a model that does not use nn.DataParallel?


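How do I load a checkpoint from a model that was trained and saved with nn.DataParallel into a model that does not use nn.DataParallel? I tried removing "module." from the state_dict keys, but now I get a different error. Here is the link to the ResNet-50 checkpoint.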

import torch
from torchvision.models import ResNet50_Weights, resnet50

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
checkpoint = torch.load(checkpoint_path)

state_dict = checkpoint['state_dict']

# creating new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

This produces the error:

RuntimeError: Error(s) in loading state_dict for ResNet:
    Unexpected key(s) in state_dict: "bn1.aux_bn.weight", "bn1.aux_bn.bias", "bn1.aux_bn.running_mean", "bn1.aux_bn.running_var", "bn1.aux_bn.num_batches_tracked", "layer1.0.bn1.aux_bn.weight", "layer1.0.bn1.aux_bn.bias", "layer1.0.bn1.aux_bn.running_mean", "layer1.0.bn1.aux_bn.running_var", "layer1.0.bn1.aux_bn.num_batches_tracked", "layer1.0.bn2.aux_bn.weight", "layer1.0.bn2.aux_bn.bias", "layer1.0.bn2.aux_bn.running_mean", "layer1.0.bn2.aux_bn.running_var", "layer1.0.bn2.aux_bn.num_batches_tracked", "layer1.0.bn3.aux_bn.weight", "layer1.0.bn3.aux_bn.bias", "layer1.0.bn3.aux_bn.running_mean", "layer1.0.bn3.aux_bn.running_var", "layer1.0.bn3.aux_bn.num_batches_tracked", "layer1.0.downsample.1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.bias", "layer1.0.downsample.1.aux_bn.running_mean", "layer1.0.downsample.1.aux_bn.running_var", "layer1.0.downsample.1.aux_bn.num_batches_tracked", "layer1.1.bn1.aux_bn.weight", "layer1.1.bn1.aux_bn.bias", "layer1.1.bn1.aux_bn.running_mean", "layer1.1.bn1.aux_bn.running_var", "layer1.1.bn1.aux_bn.num_batches_tracked", "layer1.1.bn2.aux_bn.weight", "layer1.1.bn2.aux_bn.bias", "layer1.1.bn2.aux_bn.running_mean", "layer1.1.bn2.aux_bn.running_var", "layer1.1.bn2.aux_bn.num_batches_tracked", "layer1.1.bn3.aux_bn.weight", "layer1.1.bn3.aux_bn.bias", 

Loading the checkpoint normally, without renaming the keys, like this:

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
checkpoint = torch.load(checkpoint_path)

state_dict = checkpoint['state_dict']

model.load_state_dict(state_dict)

gives the error:

Unexpected key(s) in state_dict: "module.conv1.weight",

RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", ...

Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.bn1.aux_bn.weight", "module.bn1.aux_bn.bias", "module.bn1.aux_bn.running_mean", "module.bn1.aux_bn.running_var", "module.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.bn1.num_batches_tracked", "module.layer1.0.bn1.aux_bn.weight", "module.layer1.0.bn1.aux_bn.bias", "module.layer1.0.bn1.aux_bn.running_mean", "module.layer1.0.bn1.aux_bn.running_var", "module.layer1.0.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.bn2.num_batches_tracked", "module.layer1.0.bn2.aux_bn.weight", "module.layer1.0.bn2.aux_bn.bias", "module.layer1.0.bn2.aux_bn.running_mean", "module.layer1.0.bn2.aux_bn.running_var", "module.layer1.0.bn2.aux_bn.num_batches_tracked", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", "module.layer1.0.bn3.running_mean", "module.layer1.0.bn3.running_var", "module.layer1.0.bn3.num_batches_tracked", "module.layer1.0.bn3.aux_bn.weight", "module.layer1.0.bn3.aux_bn.bias", "module.layer1.0.bn3.aux_bn.running_mean", "module.layer1.0.bn3.aux_bn.running_var", "module.layer1.0.bn3.aux_bn.num_batches_tracked", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.0.downsample.1.running_mean", "module.layer1.0.downsample.1.running_var", "module.

Thanks a lot.

python deep-learning pytorch torch torchvision
1 Answer

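You did the right thing by removing the "module." prefix. The remaining problem is that this resnet50 model was initialized with a custom normalization layer; you can see it being used here. That layer is defined in aux_bn.py, and it is what produces the keys of the form "bn*.aux_bn" that you see in the error. A stock torchvision resnet50 uses plain batch normalization, so it has no parameters with those names and rejects them as unexpected keys.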
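To check that the leftover keys really all come from the auxiliary batch-norm layers, it can help to compare the key names directly. The following is a minimal sketch, not part of the original answer: the helper name diff_keys is hypothetical, and short lists of key names (abbreviated from the error messages above) stand in for real state dicts.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar to load_state_dict's report."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # model expects, checkpoint lacks
    unexpected = sorted(checkpoint_keys - model_keys)   # checkpoint has, model does not
    return missing, unexpected


# Stand-in key names; with real objects you would pass
# model.state_dict().keys() and new_state_dict.keys() instead.
model_keys = ["conv1.weight", "bn1.weight", "bn1.bias"]
checkpoint_keys = [
    "conv1.weight", "bn1.weight", "bn1.bias",
    "bn1.aux_bn.weight", "bn1.aux_bn.bias",
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
only_aux = all(".aux_bn." in key for key in unexpected)

print(missing)     # keys the model needs but the checkpoint does not provide
print(unexpected)  # keys the checkpoint provides but the model does not define
print(only_aux)    # whether every unexpected key belongs to an aux_bn layer
```

If every unexpected key contains ".aux_bn." and nothing is missing, the mismatch comes from the architecture rather than from the prefix handling.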

The code should run once the model is initialized correctly:

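import torch

# MixBatchNorm2d is the custom norm layer from the checkpoint repository's aux_bn.py;
# import it from wherever that file lives in your project.
checkpoint = torch.load(checkpoint_path)
# Drop the 7-character "module." prefix that nn.DataParallel adds to every key
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

model = resnet50(num_classes=1_000, norm_layer=MixBatchNorm2d)
model.load_state_dict(state_dict)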
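Slicing with k[7:] assumes that every key starts with "module."; applied to a checkpoint saved without nn.DataParallel, it would silently cut off the first seven characters of each name. A slightly more defensive variant is sketched below. The helper name strip_prefix is hypothetical, and integers stand in for the tensors a real state dict would hold.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Only touch keys that actually carry the prefix; leave the rest unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Stand-in values; a real state dict maps parameter names to tensors.
example = OrderedDict([("module.conv1.weight", 1), ("fc.bias", 2)])
cleaned = strip_prefix(example)
print(list(cleaned.keys()))
```

With a real checkpoint, the result of strip_prefix(checkpoint['state_dict']) could be passed to model.load_state_dict in place of the dict comprehension above.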