PyTorch LSTM model not training

Problem description

I'm trying to train an LSTM model followed by a fully connected layer to classify EEG time-series data with 22 channels and a sequence length of 1000. I'm using PyTorch for the layers, but validation accuracy does not improve during training; if anything, the validation gets worse. I assume there is a lot of overfitting going on here, but I'm confused as to why the model still performs the same as randomly picking one class (I have 4 classes). Does anyone know why that is? I've attached the code that builds and trains the model, along with some output from training.

Code

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_layers = 1
input_size = 22
hidden_size = 32
num_classes = 4
batch_size = 50
num_epochs = 20

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        # note: nn.LSTM applies dropout only between stacked layers,
        # so it is a no-op (and warns) when num_layers == 1
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                            dropout=0.5, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out


# Instantiate the model
model = LSTMClassifier(input_size, hidden_size, num_classes)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for i in range(0, len(X_train), batch_size):
        optimizer.zero_grad()

        # Get batch data
        batch_input = X_train[i:i+batch_size, :, :input_size]  # Ensure input size matches
        batch_target = y_train[i:i+batch_size]

        # Forward pass
        outputs = model(batch_input)

        # Calculate loss
        loss = criterion(outputs, batch_target)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += batch_target.size(0)
        correct += (predicted == batch_target).sum().item()

    accuracy = 100 * correct / total
    # Validation
    model.eval()
    total_loss_valid, correct_valid, total_valid = 0.0, 0, 0

    with torch.no_grad():
        outputs = model(X_valid)
        loss_valid = criterion(outputs, y_valid)
        total_loss_valid += loss_valid.item()

        _, predicted_valid = torch.max(outputs, 1)
        total_valid += y_valid.size(0)
        correct_valid += (predicted_valid == y_valid).sum().item()

    accuracy_valid = 100 * correct_valid / total_valid

    print("Epoch:", epoch+1, "\t\tTraining Loss:", total_loss, "\tTraining Accuracy:", accuracy)
    print("\t\t\tValidation Loss:", total_loss_valid, "\tValidation Accuracy:", accuracy_valid)

print("Model training complete.")

Training output

Epoch: 1        Training Loss: 47.797505140304565   Training Accuracy: 24.231678486997637
            Validation Loss: 1.3961728811264038     Validation Accuracy: 23.87706855791962
Epoch: 2        Training Loss: 46.88989615440369    Training Accuracy: 27.77777777777778
            Validation Loss: 1.3933675289154053     Validation Accuracy: 24.34988179669031
Epoch: 3        Training Loss: 46.60832667350769    Training Accuracy: 30.61465721040189
            Validation Loss: 1.3946889638900757     Validation Accuracy: 28.368794326241133
Epoch: 4        Training Loss: 46.370394468307495   Training Accuracy: 31.85579196217494
            Validation Loss: 1.3951939344406128     Validation Accuracy: 27.89598108747045
Epoch: 5        Training Loss: 46.11502695083618    Training Accuracy: 34.1016548463357
            Validation Loss: 1.3980598449707031     Validation Accuracy: 27.659574468085108
Epoch: 6        Training Loss: 45.86429834365845    Training Accuracy: 35.579196217494086
            Validation Loss: 1.399003267288208  Validation Accuracy: 26.24113475177305
Epoch: 7        Training Loss: 45.610363483428955   Training Accuracy: 37.35224586288416
            Validation Loss: 1.4000375270843506     Validation Accuracy: 24.34988179669031
Epoch: 8        Training Loss: 45.35917508602142    Training Accuracy: 37.70685579196218
            Validation Loss: 1.4008349180221558     Validation Accuracy: 26.24113475177305
Epoch: 9        Training Loss: 45.06815278530121    Training Accuracy: 39.657210401891255
            Validation Loss: 1.4022436141967773     Validation Accuracy: 25.29550827423168
Epoch: 10       Training Loss: 44.72212290763855    Training Accuracy: 41.016548463356976
            Validation Loss: 1.4047679901123047     Validation Accuracy: 23.87706855791962
Epoch: 11       Training Loss: 44.39665925502777    Training Accuracy: 41.78486997635934
            Validation Loss: 1.4072729349136353     Validation Accuracy: 25.768321513002363
Epoch: 12       Training Loss: 44.04106819629669    Training Accuracy: 42.434988179669034
            Validation Loss: 1.409785509109497  Validation Accuracy: 24.34988179669031
Epoch: 13       Training Loss: 43.70037293434143    Training Accuracy: 43.73522458628842
            Validation Loss: 1.4153059720993042     Validation Accuracy: 25.059101654846337
Epoch: 14       Training Loss: 43.31694221496582    Training Accuracy: 45.62647754137116
            Validation Loss: 1.4185421466827393     Validation Accuracy: 23.87706855791962
Epoch: 15       Training Loss: 42.95249891281128    Training Accuracy: 46.74940898345154
            Validation Loss: 1.425550937652588  Validation Accuracy: 24.58628841607565
Epoch: 16       Training Loss: 42.523557305336  Training Accuracy: 47.75413711583924
            Validation Loss: 1.4291025400161743     Validation Accuracy: 24.113475177304963
Epoch: 17       Training Loss: 42.142807960510254   Training Accuracy: 48.758865248226954
            Validation Loss: 1.4345715045928955     Validation Accuracy: 24.822695035460992
Epoch: 18       Training Loss: 41.72265028953552    Training Accuracy: 49.645390070921984
            Validation Loss: 1.4420568943023682     Validation Accuracy: 24.34988179669031
Epoch: 19       Training Loss: 41.276421666145325   Training Accuracy: 50.354609929078016
            Validation Loss: 1.449327826499939  Validation Accuracy: 25.53191489361702
Epoch: 20       Training Loss: 40.81519865989685    Training Accuracy: 51.59574468085106
            Validation Loss: 1.4553029537200928     Validation Accuracy: 25.29550827423168
Model training complete.
Tags: python, pytorch, lstm
1 Answer

Your training loss is decreasing while your validation loss is increasing, which is pretty much the definition of overfitting.

There can be many reasons a model overfits, depending on your setup. Most of the time it comes down to having too little data (you should report your dataset size) or too large a model (looking at yours, I would be surprised if that were the case here).
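If overfitting is the culprit, two common mitigations in this setup are L2 regularization via Adam's `weight_decay` and making the LSTM's dropout actually take effect (with `num_layers=1`, PyTorch ignores `dropout` and emits a warning, since it is only applied between stacked layers). A minimal sketch, reusing the hyperparameters from the question; the two-layer LSTM and the `weight_decay` value are assumptions, not part of the original code:

```python
import torch
import torch.nn as nn

input_size, hidden_size, num_classes = 22, 32, 4

class RegularizedLSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, num_layers=2):
        super().__init__()
        # dropout between stacked LSTM layers only takes effect when num_layers > 1
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            batch_first=True, dropout=0.5)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)          # zero initial states are the default
        return self.fc(out[:, -1, :])  # classify from the last time step

model = RegularizedLSTMClassifier(input_size, hidden_size, num_classes)
# weight_decay adds L2 regularization on top of Adam
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```

Early stopping on validation loss would also be worth trying, since your validation loss starts drifting up after the first couple of epochs.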

In general, I'd recommend looking up external resources on what overfitting is and why it happens. This one is good.

That said, I find it very suspicious that you're only getting ~25% accuracy on a 4-class classification problem. Unless you have very little data, I'd say there is probably another issue with data loading or elsewhere. Unfortunately, I can't assess that part from the code you've provided.
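Two quick sanity checks worth running on that front: verify the class distribution (a heavily skewed label set would explain chance-level validation accuracy), and reshuffle the training set each epoch, since the loop in the question walks the data in a fixed order. A sketch with synthetic stand-ins for your tensors (the shapes follow the post; the dataset size of 200 is made up):

```python
import torch

# synthetic stand-ins for X_train / y_train (22 channels, sequence length 1000)
X_train = torch.randn(200, 1000, 22)
y_train = torch.randint(0, 4, (200,))

# 1) class balance: counts per label should be roughly even
counts = torch.bincount(y_train, minlength=4)
print(counts)

# 2) reshuffle before each epoch so batches are not always identical
perm = torch.randperm(len(X_train))
X_train, y_train = X_train[perm], y_train[perm]
```

Using a `DataLoader` with `shuffle=True` would accomplish the same thing more idiomatically.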
