具有嵌入层的PyTorch朴素单标签分类随机失败

问题描述 投票:3回答:1

我是PyTorch的新手,我正在尝试嵌入层。

我写了一个天真的分类任务,其中所有输入都相等,所有标签都设置为1.0。因此,我希望模型能够快速学习以预测1.0。

输入始终为0,它被送入nn.Embedding(1,32)层,然后是nn.Linear(32,1)和nn.Relu()。

但是,会出现意外和不期望的行为:在运行代码的不同时间,培训结果会有所不同。例如,

  • 将随机种子设置为10,模型收敛:损失减少,模型总是预测1.0
  • 将随机种子设置为1111,模型不收敛:损失不减少,模型总是预测0.5。在这些情况下,参数不会更新

这是最小的可复制代码:

from torch.nn import BCEWithLogitsLoss
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch


class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.vgg_fc = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        self.embeddings = nn.Embedding(1, 32)

    def forward(self, data):
        emb = self.embeddings(data['index'])
        return self.relu(self.vgg_fc(emb))


class MyDataset(Dataset):

    def __init__(self):
        pass
    def __len__(self):
        return 1000
    def __getitem__(self, idx):
        return {'label': 1.0, 'index': 0}


def train():
    model = MyModel()
    db = MyDataset()
    dataloader = DataLoader(db, batch_size=256, shuffle=True, num_workers=16)

    loss_function = BCEWithLogitsLoss()
    optimizer_rel = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(50):
        for i_batch, sample_batched in enumerate(dataloader):

            model.zero_grad()
            out = model({'index': Variable(sample_batched['index'])})

            labels = Variable(sample_batched['label'].type(torch.FloatTensor).view(sample_batched['label'].shape[0], 1))

            loss = loss_function(out, labels)
            loss.backward()
            optimizer_rel.step()
            print 'Epoch:', epoch, 'batch', i_batch, 'Tr_Loss:', loss.data[0]
    return model


if __name__ == '__main__':

    # please, try seed 10 (converge) and seed 1111 (fails)
    torch.manual_seed(10)
    train()

如果不指定随机种子,则不同的运行具有不同的结果。

为什么在这种情况下,模型无法学习如此简单的任务?我使用nn.Embedding层的方式有什么错误吗?

谢谢

python machine-learning embedding pytorch random-seed
1个回答
0
投票

我发现问题是最终的relu层,在sigmoid之前。如here所述,该层将:

扔掉信息而不增加任何额外的好处

删除该层,网络学习任何种子预期。

© www.soinside.com 2019 - 2024. All rights reserved.