model.train() 在 PyTorch 中做什么?

问题描述 投票:0回答:6

它在

forward()
中调用
nn.Module
吗?我认为当我们调用模型时,正在使用
forward
方法。 为什么我们需要指定train()?

python machine-learning deep-learning pytorch
6个回答
303
投票

model.train()
告诉您的模型您正在训练该模型。这有助于为 Dropout 和 BatchNorm 等层提供信息,这些层旨在在训练和评估期间表现不同。例如,在训练模式下,BatchNorm 更新每个新批次的移动平均值;然而,对于评估模式,这些更新被冻结。

更多详情:

model.train()
设置训练模式 (参见源代码)。您可以致电
model.eval()
model.train(mode=False)
告知您正在测试。 期望
train
函数来训练模型有点直观,但它并没有这样做。它只是设置模式。


97
投票

这是

nn.Module.train()
的代码:

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这是

nn.Module.eval()
的代码:

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

默认情况下,

self.training
标志设置为
True
,即模块默认处于训练模式。当
self.training
False
时,模块处于相反状态,即 eval 模式。

在最常用的层中,只有

Dropout
BatchNorm
关心该标志。


41
投票
model.train()
model.eval()
training 模式下设置模型,即

BatchNorm
层使用每批次统计信息
Dropout
层已激活 etc
将模型设置为 evaluation(推理)模式,即

BatchNorm
层使用运行统计信息
Dropout
层已停用等
相当于
model.train(False)

注意: 这些函数调用都不会向前/向后传递。他们告诉模型如何行动何时运行。

这很重要,因为某些模块(层)(例如

Dropout
BatchNorm
)被设计为在训练与推理期间表现不同,因此如果在错误的模式下运行,模型将产生意外的结果。


17
投票

有两种方式让模型知道你的意图,即你想要训练模型还是想要使用模型进行评估。 在

model.train()
的情况下,模型知道它必须学习各层,当我们使用
model.eval()
时,它表明模型不需要学习任何新内容,并且该模型用于测试。
model.eval()
也是必要的,因为在pytorch中,如果我们使用batchnorm,并且在测试过程中,如果我们只想传递单个图像,如果未指定
model.eval()
,pytorch会抛出错误。


3
投票

考虑以下模型

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

这里,

dropout
的功能在不同的操作模式下有所不同。正如您所看到的,它仅在
self.training==True
时才起作用。因此,当您输入
model.train()
时,模型的前向函数将执行 dropout,否则不会(例如当
model.eval()
model.train(mode=False)
时)。


2
投票

当前的官方文档声明如下:

这仅对某些模块产生[原文]影响。请参阅特定模块的文档,了解其在培训/评估模式下的行为详细信息(如果它们受到影响),例如Dropout、BatchNorm 等

© www.soinside.com 2019 - 2024. All rights reserved.