如何将BATN的BatchNorm重量转换为pytorch BathNorm?

问题描述 投票:0回答:1

可以从pycaffe中读取Caffe模型的BathNorm和Scale权重,它们是BatchNorm中的三个权重和Scale中的两个权重。我尝试将这些权重复制到pytorch BatchNorm,代码如下:

if 'conv3_final_bn' == name:
    assert len(blobs) == 3, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.running_mean'] = blobs[0].data
    torch_mod['conv3_final_bn.running_var'] = blobs[1].data
elif 'conv3_final_scale' == name:
    assert len(blobs) == 2, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.weight'] = blobs[0].data
    torch_mod['conv3_final_bn.bias'] = blobs[1].data

两个BatchNorm的行为不同。我还尝试设置conv3_final_bn.weight = 1和conv3_final_bn.bias = 0来验证caffe的BN层,结果也不匹配。

我该如何处理错误的匹配?

caffe pytorch pycaffe
1个回答
0
投票

得到它了!在BatchNorm中,caffe还有第三个参数。代码应该是:

if 'conv3_final_bn' == name:
    assert len(blobs) == 3, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.running_mean'] = blobs[0].data / blobs[2].data[0]
    torch_mod['conv3_final_bn.running_var'] = blobs[1].data / blobs[2].data[0]
elif 'conv3_final_scale' == name:
    assert len(blobs) == 2, '{} layer blob count: {}'.format(name, len(blobs))
    torch_mod['conv3_final_bn.weight'] = blobs[0].data
    torch_mod['conv3_final_bn.bias'] = blobs[1].data
© www.soinside.com 2019 - 2024. All rights reserved.