Printing the loss per epoch or per index_data in PyTorch

Problem description

I am new to PyTorch, so please help. I am confused about where epoch_loss should be computed when training on the dataset:

def train(net, trainloader, epochs, use_gpu=True):
    # ..
    net.train()
    # Train the network
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        running_loss = 0.0
        running_corrects = 0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            epoch_loss = running_loss / len(trainloader.dataset)

            print('Loss: {}'.format(epoch_loss))

Should I print the loss where it is in the code above, or once per epoch, like this:

for epoch in range(epochs):
    for i, data in enumerate(trainloader, 0):
        ...
    epoch_loss = running_loss / len(trainloader.dataset)
    print('Loss: {}'.format(epoch_loss))
    
python-3.x pytorch
1 Answer

It depends on what you want to print. The batch loss is the model's loss on one particular batch; the epoch loss is the average of all the batch losses within an epoch. If you want the batch loss, print loss.item() after every batch rather than the accumulated running_loss:

def train(net, trainloader, epochs, use_gpu=True):
    net.train()
    # Train the network
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        running_loss = 0.0
        running_corrects = 0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Print the loss of this batch only
            print('Batch loss: {}'.format(loss.item()))

On the other hand, if you want to print the epoch loss, use the second option: compute epoch_loss after the inner loop has finished, so it summarizes the whole epoch.
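For completeness, here is a minimal sketch of that second option, assuming net, device, optimizer and criterion are defined as in the question. One caveat: with PyTorch's default reduction='mean', loss.item() is already averaged over the batch, so dividing the accumulated total by len(trainloader) (the number of batches) gives the average batch loss; dividing by len(trainloader.dataset) (the number of samples) only gives a per-sample average when the criterion is built with reduction='sum'.

def train(net, trainloader, epochs, use_gpu=True):
    # Sketch only: assumes device, optimizer and criterion exist, as in the question.
    net.train()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # mean loss of this batch
        # Printed once per epoch, after all batches have been seen.
        epoch_loss = running_loss / len(trainloader)  # average of the batch losses
        print('Epoch loss: {}'.format(epoch_loss))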
