如何缓解lstm的过度拟合问题,或者也许我误解了lstm训练?

问题描述 投票:0回答:1

我的目标是通过篮球的轨迹来预测投篮是否命中(已经有论文做了同样的事情,我只是复制它们(https://arxiv.org/abs/1608.03793) )。这是一个非常简单的问题,但出于其他目的我想尝试 lstm 的预测。但网络不仅收敛速度慢,而且很快就开始过度拟合。 时间序列的长度为 12,特征维度为 4(x,y,z 坐标和时间)


class PolicyHead(nn.Layer):
    def __init__(self):
        super(PolicyHead, self).__init__()
        self.lstm = LSTM(4,64,2,dropout=0.3)
        self.fc = nn.Linear(64,2)

    def forward(self, past_traj):
        b,n,c = past_traj.shape
         
        return self.fc(self.lstm(past_traj, batch_id))

enter image description here

作为比较,使用了非常简单的 mlp 结构,并且预测非常好。

class PolicyHead(nn.Layer):
    def __init__(self):
        super(PolicyHead, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(12*4, 512),
            nn.ReLU(),
            nn.Linear(512,512),
            nn.ReLU(),
            nn.Linear(512,2)
        )

    def forward(self, past_traj):
        b,n,c = past_traj.shape
        return self.fc(past_traj.reshape([b,-1]))

enter image description here

python lstm
1个回答
0
投票

我没有看到你在哪里设置训练优化器,但你也可以使用L1/L2正则化来更新学习率;这将有助于解决过度拟合问题,而且我认为无论您对过度拟合问题做出什么决定,这都是有益的。

现在要直接解决过度拟合问题,您可以添加批量归一化层,将 dropout 率调整为 0.5 和/或使用循环 dropout(您可以一起使用它们)。

当数据具有高方差并且批量大小已经很大时,批量归一化非常有用。

dropout和recurrent dropout只在训练过程中起作用,所以不用担心影响推理阶段。

class PolicyHead(nn.Layer):
    def __init__(self):
        super(PolicyHead, self).__init__()
        self.lstm = LSTM(4,64,2,dropout=0.3, recurrent_dropout=0.3)
        self.fc = nn.Linear(64,2)

    def forward(self, past_traj):
        b,n,c = past_traj.shape
         
        return self.fc(self.lstm(past_traj, batch_id))

你可以一起使用它们,但技巧是尝试一个错误,单独使用循环丢失,尝试一些值,并测试它的行为。

现在,如果您想应用 BatchNormalization,您将需要多更新一点。在你的情况下

class PolicyHead(nn.Layer):
    def __init__(self):
        super(PolicyHead, self).__init__()
        self.lstm = LSTM(4,64,2,dropout=0.3, recurrent_dropout=0.3) #Assuming this is nn.LSTM
        self.batch_norm = nn.BatchNorm1d(64)
        self.fc = nn.Linear(64,2)

    def forward(self, past_traj):
        b,n,c = past_traj.shape
        lstm_out = self.lstm(past_traj, batch_id)[:-1:]
        return self.fc(self.batch_norm(lstm_out)

希望这有帮助

© www.soinside.com 2019 - 2024. All rights reserved.