即使在评估模式下也有必要调用torch.no_grad()吗?

问题描述 投票:0回答:1

我正在学习

pytorch
。在代码示例中,模型通过使用
model.train()
model.eval()
模式在训练和测试之间切换。我知道必须这样做才能在测试时停用训练特定行为,例如退出、归一化和梯度计算。

在许多这样的例子中,他们还使用

torch.no_grad()
,我理解这是一种明确要求停止梯度计算的方式。

我的问题是,如果在

model.eval()
模式下停止梯度计算那为什么我们还必须设置
torch.no_grad()

下面是同时使用

model.eval()
torch.no_grad()
的代码示例

class QuertyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.input = nn.Linear(2, 8)
        
        self.hl1 = nn.Linear(8, 16)
        self.hl2 = nn.Linear(16, 32)
        self.hl3 = nn.Linear(32, 16)
        self.hl4 = nn.Linear(16, 8)
        
        self.output = nn.Linear(8, 3)
        
    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hl1(x))
        x = F.relu(self.hl2(x))
        x = F.relu(self.hl3(x))
        x = F.relu(self.hl4(x))
        
        return self.output(x)

训练模型的函数

def train_model(model, lr, num_epochs):
    
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    train_acc, train_loss, test_acc, test_loss = [], [], [], []
    for epoch in range(num_epochs):
        print(f"{epoch+1}/{num_epochs}")
        
        model.train() # switch to training mode
        
        batch_acc, batch_loss = [], []
        for X, y in train_loader:
            # forward pass
            y_hat = model(X)
            loss = loss_function(y_hat, y)
            
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            matches = torch.argmax(y_hat, axis=1) == y # true/false
            matches = matches.float() # 0/1
            
            batch_acc.append(100*torch.mean(matches))
            batch_loss.append(loss.item())
        
        train_acc.append(np.mean(batch_acc))
        train_loss.append(np.mean(batch_loss))
        
        model.eval() # switch to evaluation mode
        
        X, y = next(iter(test_loader))
        with torch.no_grad():
            y_hat = model(X)
        
        test_acc.append(100*(torch.mean(y_hat) == y).float())    
        test_loss.append(loss_function(y_hat, y).item())
        
    return train_acc, train_loss, test_acc, test_loss
python pytorch
1个回答
0
投票

.train()
.eval()
仅修改网络某些层的某些行为(如BatchNorm、Dropout等),它们没有明确允许或阻止模型的梯度,因此为了进行推理,您应该调用
torch.no_grad() 
停止存储中间值

© www.soinside.com 2019 - 2024. All rights reserved.