在 PyTorch 中注册一个 dict 对象

问题描述 投票:0回答:1

我以为这是一个简单的问题,但我找不到答案。

我希望使用模型 state_dict 保存/加载 pytorch 模块的成员变量。我可以在 init 中使用以下行来完成此操作。

        self.register_buffer('loss_weight', torch.tensor(loss_weight))

但是如果loss_weight是一个dict对象怎么办?允许吗?如果是这样,我怎样才能将它转换为张量?

尝试时,我收到错误“无法推断 dict 的 dtype。”

python pytorch state-dict
1个回答
0
投票

根据 docs

name
参数必须是字符串,
tensor
参数必须是 pytorch 张量。

如果您有缓冲区字典,您可以考虑使用专用的

nn.Module
来实现此目的。像这样的东西:

class BufferDict(nn.Module):
    def __init__(self, input_dict):
        super().__init__()
        for k,v in input_dict.items():
            self.register_buffer(k, v)
            
input_dict = {'a' : torch.randn(4), 'b' : torch.randn(5)}

bd = BufferDict(input_dict)
bd.state_dict()
> OrderedDict([('a', tensor([ 0.1908,  1.6965, -0.3710,  0.4551])),
               ('b', tensor([-0.6943, -0.0534,  0.1779,  1.3607, -0.2236]))])
© www.soinside.com 2019 - 2024. All rights reserved.