打印火车统计数据

问题描述 投票:0回答:1

为了训练 cifar100 数据集,我找到了这个函数 train,虽然我是 Pytorch 的新手,但我想了解值 10000,因为当我更改它时,损失会发生变化

def train(net,trainloader,epochs,use_gpu = True):
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")

    # Train the network
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            images, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 10000))
                running_loss = 0.0
python pytorch
1个回答
0
投票

代码将 100 个小批量的损失相加:

running_loss += loss
。每 100 个小批量 (
(i + 1) % 100 == 0
),您需要将
running_loss
除以 100 才能获得平均值。然后,代码重置
running_loss
(
running_loss=0
),然后开始再次累加接下来 100 个小批量的损失。

代码有错误;它应该除以“100”,而不是“10, 000”,因为每次都会累积 100 个值。

© www.soinside.com 2019 - 2024. All rights reserved.