我以为这是一个简单的问题,但我找不到答案。
我希望使用模型 state_dict 保存/加载 pytorch 模块的成员变量。我可以在 init 中使用以下行来完成此操作。
self.register_buffer('loss_weight', torch.tensor(loss_weight))
但是如果loss_weight是一个dict对象怎么办?允许吗?如果是这样,我怎样才能将它转换为张量?
尝试时,我收到错误“无法推断 dict 的 dtype。”
根据 docs,
name
参数必须是字符串,tensor
参数必须是 pytorch 张量。
如果您有缓冲区字典,您可以考虑使用专用的
nn.Module
来实现此目的。像这样的东西:
class BufferDict(nn.Module):
def __init__(self, input_dict):
super().__init__()
for k,v in input_dict.items():
self.register_buffer(k, v)
input_dict = {'a' : torch.randn(4), 'b' : torch.randn(5)}
bd = BufferDict(input_dict)
bd.state_dict()
> OrderedDict([('a', tensor([ 0.1908, 1.6965, -0.3710, 0.4551])),
('b', tensor([-0.6943, -0.0534, 0.1779, 1.3607, -0.2236]))])