pytorch 中的自定义模型:可训练和不可训练参数

问题描述 投票:0回答:1

我想使用 Pytorch 创建一个自定义模型,其中我需要将输入与包含可训练和不可训练参数的矩阵相乘(我希望实现一个可训练的卡尔曼滤波器,具有自由和固定参数)。此外,这样的矩阵在多个条目中具有相同的参数。

但是我在训练中很挣扎(也许太多了!)...有什么解决方法吗?

class CustomModel(torch.nn.Module):
    def __init__(self,w0):
        super(CustomModel, self).__init__()
        self.w = torch.nn.Parameter(data = torch.tensor([w0], dtype=torch.float32, requires_grad=True))
    
        #self.matrix = torch.tensor(data = [[self.w, -1.],[-self.w, -1.]], dtype=torch.float32, requires_grad=True    This computes \partial_matrix(COST) --> BAD

        self.matrix_trainable = self.w*torch.tensor(data=[[0,1],[-1,0]], dtype=torch.float32,requires_grad=False)
        self.matrix = self.matrix_trainable - torch.eye(2)
    
    def forward(self, x):
        return self.matrix.matmul(x)

def loss(pred,y):
    return torch.mean((pred- y)**2)


my_model = CustomModel(w0=0.01)
optimizer = torch.optim.Adam(lr=0.01, params=my_model.parameters())

device = torch.device("cpu")
x = torch.ones(2).to(device)
y = torch.tensor(data=[2.,0.], dtype=torch.float32).to(device)


for k in range(10):

    optimizer.zero_grad()
    my_model.zero_grad()
    pred = my_model(x)
    cost = loss(pred,y)
    cost.backward()
    optimizer.step()

运行时错误:尝试第二次向后浏览图表(或者在释放张量后直接访问保存的张量)。当您调用 .backward() 或 autograd.grad() 时,保存的图表中间值将被释放。如果您需要第二次向后浏览图表或者需要在向后调用后访问保存的张量,请指定retain_graph=True。

火炬版本2.0.1

非常感谢!

马蒂亚斯

deep-learning pytorch neural-network kalman-filter custom-training
1个回答
0
投票

在您的代码中,我可以看到您的目的是将矩阵乘法分为两步。首先,您有一个可训练的

A
和不可训练的
B
。那么我建议以下实现

class CustomModel(torch.nn.Module):
    def __init__(self,w0):
        super(CustomModel, self).__init__()
        self.A = torch.nn.Parameter(data = torch.tensor([w0], dtype=torch.float32, requires_grad=True))

        B = torch.tensor([[1,0],[0,1]]) # or any values you want
        self.register_buffer("B", B, persistent=False)
    
    def forward(self, x):
        return self.B @ self.A @ x # or something like this


register_buffer
使得模型的属性(这里是参数)不会通过调用
model.parameters()
返回,所以
B
在训练过程中不会受到
optimizer.step()
的影响

© www.soinside.com 2019 - 2024. All rights reserved.