通过 CNN 网络中正确创建的两个不同分支进行反向传播

问题描述 投票:0回答:1

我有一个 CNN 网络,我在其中创建了两个不同的分支,输出将一个用于分类,一个用于回归
现在想要通过两个分支进行反向传播,以便模型可以很好地学习

        self.gender_branch = nn.Linear(int(img_size * 0.25), 2)  # Binary classification for gender (0 or 1)
        self.age_branch = nn.Linear(int(img_size * 0.25), 1)  # Regression for age (between 0 and 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Separate optimizers for each branch
optimizer_gender = torch.optim.Adam(model.gender_branch.parameters(), lr=0.001)
optimizer_age = torch.optim.Adam(model.age_branch.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
criterion_age = nn.MSELoss()
for e in range(epochs):
    epoch_start_time = time.time()
    running_loss_gender, running_loss_age = 0, 0
    accuracy_gender, accuracy_age = 0, 0

    model.train()
    for images, labels in train_dataloader:
        images = images.to(device)
        labels_gender = labels[:, 0, 0].long().to(device)
        labels_age = labels[:, 0, 1].view(-1, 1).to(device)

        optimizer.zero_grad()


        output_gender, output_age = model(images)

        # Gender loss and accuracy
        loss_gender = criterion(output_gender, labels_gender)
        running_loss_gender += loss_gender.item()
        accuracy_gender += torch.sum(torch.argmax(output_gender, dim=1) == labels_gender).item()

        # Age regression loss (MSE loss)
        loss_age = criterion_age(output_age, labels_age)
        running_loss_age += loss_age.item()
        accuracy_age += torch.sum(torch.abs(output_age - labels_age) < 0.10).item()

        # Backward pass for gender branch
        loss_gender.backward(retain_graph=True)
 

        # # Backward pass for age branch
        loss_age.backward(retain_graph=True)

        optimizer.step()

现在对于两个分支,我希望模型分别收敛,以便我可以分别为两个分支保存 pytorch 模型,
这是通过整个基础模型以及两个分支分别反向传播的正确方法吗? 上面的代码只保存了2kb的模型并且偏向于一类如何纠正它?

deep-learning pytorch conv-neural-network
1个回答
0
投票
通常,您执行此操作的方式只是简单的

loss = loss_gender + loss_age

loss.backward()
,而无需尝试维护单独的模型。

只有最后一个性别分支和年龄分支不同,一旦训练完成,您就可以从模型中删除其中任何一个。

© www.soinside.com 2019 - 2024. All rights reserved.