I'm trying to train an LSTM model, followed by a fully connected layer, to classify a set of EEG time-series data with 22 channels (sequence length 1000). I'm using PyTorch for the layers, but the validation accuracy does not improve during training. If anything, validation gets worse. I assume there's a lot of overfitting going on here, but I'm confused why it still performs about the same as randomly picking a class (I have 4 classes). Does anyone know why this is? I've attached the code used to create and train the model, along with some output from training.
num_layers = 1
input_size = 22
hidden_size = 32
num_classes = 4
batch_size = 50
num_epochs = 20

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, dropout=0.5, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(num_layers, x.size(0), self.hidden_size).to(x.device)
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

# Instantiate the model
model = LSTMClassifier(input_size, hidden_size, num_classes)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for i in range(0, len(X_train), batch_size):
        optimizer.zero_grad()
        # Get batch data
        batch_input = X_train[i:i+batch_size, :, :input_size]  # Ensure input size matches
        batch_target = y_train[i:i+batch_size]
        # Forward pass
        outputs = model(batch_input)
        # Calculate loss
        loss = criterion(outputs, batch_target)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += batch_target.size(0)
        correct += (predicted == batch_target).sum().item()
    accuracy = 100 * correct / total

    # Validation
    model.eval()
    total_loss_valid, correct_valid, total_valid = 0.0, 0, 0
    with torch.no_grad():
        outputs = model(X_valid)
        loss_valid = criterion(outputs, y_valid)
        total_loss_valid += loss_valid.item()
        _, predicted_valid = torch.max(outputs, 1)
        total_valid += y_valid.size(0)
        correct_valid += (predicted_valid == y_valid).sum().item()
        accuracy_valid = 100 * correct_valid / total_valid

    print("Epoch:", epoch+1, "\t\tTraining Loss:", total_loss, "\tTraining Accuracy:", accuracy)
    print("\t\t\tValidation Loss:", total_loss_valid, "\tValidation Accuracy:", accuracy_valid)

print("Model training complete.")
Epoch: 1 Training Loss: 47.797505140304565 Training Accuracy: 24.231678486997637
Validation Loss: 1.3961728811264038 Validation Accuracy: 23.87706855791962
Epoch: 2 Training Loss: 46.88989615440369 Training Accuracy: 27.77777777777778
Validation Loss: 1.3933675289154053 Validation Accuracy: 24.34988179669031
Epoch: 3 Training Loss: 46.60832667350769 Training Accuracy: 30.61465721040189
Validation Loss: 1.3946889638900757 Validation Accuracy: 28.368794326241133
Epoch: 4 Training Loss: 46.370394468307495 Training Accuracy: 31.85579196217494
Validation Loss: 1.3951939344406128 Validation Accuracy: 27.89598108747045
Epoch: 5 Training Loss: 46.11502695083618 Training Accuracy: 34.1016548463357
Validation Loss: 1.3980598449707031 Validation Accuracy: 27.659574468085108
Epoch: 6 Training Loss: 45.86429834365845 Training Accuracy: 35.579196217494086
Validation Loss: 1.399003267288208 Validation Accuracy: 26.24113475177305
Epoch: 7 Training Loss: 45.610363483428955 Training Accuracy: 37.35224586288416
Validation Loss: 1.4000375270843506 Validation Accuracy: 24.34988179669031
Epoch: 8 Training Loss: 45.35917508602142 Training Accuracy: 37.70685579196218
Validation Loss: 1.4008349180221558 Validation Accuracy: 26.24113475177305
Epoch: 9 Training Loss: 45.06815278530121 Training Accuracy: 39.657210401891255
Validation Loss: 1.4022436141967773 Validation Accuracy: 25.29550827423168
Epoch: 10 Training Loss: 44.72212290763855 Training Accuracy: 41.016548463356976
Validation Loss: 1.4047679901123047 Validation Accuracy: 23.87706855791962
Epoch: 11 Training Loss: 44.39665925502777 Training Accuracy: 41.78486997635934
Validation Loss: 1.4072729349136353 Validation Accuracy: 25.768321513002363
Epoch: 12 Training Loss: 44.04106819629669 Training Accuracy: 42.434988179669034
Validation Loss: 1.409785509109497 Validation Accuracy: 24.34988179669031
Epoch: 13 Training Loss: 43.70037293434143 Training Accuracy: 43.73522458628842
Validation Loss: 1.4153059720993042 Validation Accuracy: 25.059101654846337
Epoch: 14 Training Loss: 43.31694221496582 Training Accuracy: 45.62647754137116
Validation Loss: 1.4185421466827393 Validation Accuracy: 23.87706855791962
Epoch: 15 Training Loss: 42.95249891281128 Training Accuracy: 46.74940898345154
Validation Loss: 1.425550937652588 Validation Accuracy: 24.58628841607565
Epoch: 16 Training Loss: 42.523557305336 Training Accuracy: 47.75413711583924
Validation Loss: 1.4291025400161743 Validation Accuracy: 24.113475177304963
Epoch: 17 Training Loss: 42.142807960510254 Training Accuracy: 48.758865248226954
Validation Loss: 1.4345715045928955 Validation Accuracy: 24.822695035460992
Epoch: 18 Training Loss: 41.72265028953552 Training Accuracy: 49.645390070921984
Validation Loss: 1.4420568943023682 Validation Accuracy: 24.34988179669031
Epoch: 19 Training Loss: 41.276421666145325 Training Accuracy: 50.354609929078016
Validation Loss: 1.449327826499939 Validation Accuracy: 25.53191489361702
Epoch: 20 Training Loss: 40.81519865989685 Training Accuracy: 51.59574468085106
Validation Loss: 1.4553029537200928 Validation Accuracy: 25.29550827423168
Model training complete.
Your training loss is decreasing while your validation loss is increasing, which is pretty much the definition of overfitting.
There can be many reasons a model overfits, depending on your setup. Most of the time it comes down to either too little data (you should report your dataset size) or too large a model (looking at your model, I'd be surprised if that were the case here).
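As a sketch of the usual first-line remedies (weight decay plus dropout that actually fires), here is a lightly regularized variant of your model. Note one detail in your code: `nn.LSTM`'s `dropout=` argument only applies *between* stacked layers, so with `num_layers=1` it is a no-op (PyTorch even warns about this). The class name `LSTMClassifierReg` and the hyperparameter values below are just illustrative:

```python
import torch
import torch.nn as nn

class LSTMClassifierReg(nn.Module):
    """Same shape as the question's model, but with dropout that takes effect."""
    def __init__(self, input_size=22, hidden_size=32, num_classes=4):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        # nn.LSTM's dropout= only acts between stacked layers, so with
        # num_layers=1 it does nothing; apply dropout explicitly instead.
        self.drop = nn.Dropout(0.5)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)            # h0/c0 default to zeros
        out = self.drop(out[:, -1, :])   # last time step, then dropout
        return self.fc(out)

model = LSTMClassifierReg()
# L2 regularization via weight decay on the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```

Whether `weight_decay=1e-4` is the right strength depends on your data; treat it as a starting point to tune.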
In general, I'd recommend looking up external resources on overfitting and why it happens. This one is a good one.
That being said, I find it very suspicious that you're only getting ~25% accuracy on a 4-class classification task. Unless you have very little data, I'd say there's probably another problem in the data loading or elsewhere. Unfortunately, I can't assess that part from the code you've provided.
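If you want to rule out a data problem, a few cheap sanity checks help: the label distribution, per-channel normalization, and per-epoch shuffling (your loop iterates `i:i+batch_size` over the tensor in order, so if the data is sorted by class you get single-class batches). The tensors below are random stand-ins for your `X_train`/`y_train`, assumed shaped `(N, 1000, 22)` with class indices 0-3:

```python
import torch

# Hypothetical stand-ins for the question's tensors
X_train = torch.randn(100, 1000, 22)
y_train = torch.randint(0, 4, (100,))

# 1) Check the label distribution: a heavily skewed split can pin
#    accuracy near chance level.
print(torch.bincount(y_train, minlength=4))

# 2) Normalize each channel to zero mean / unit variance; raw EEG
#    amplitudes vary a lot between channels otherwise.
mean = X_train.mean(dim=(0, 1), keepdim=True)   # shape (1, 1, 22)
std = X_train.std(dim=(0, 1), keepdim=True)
X_train = (X_train - mean) / (std + 1e-8)

# 3) Shuffle once per epoch so each batch mixes classes.
perm = torch.randperm(len(X_train))
X_train, y_train = X_train[perm], y_train[perm]
```

A `torch.utils.data.DataLoader` with `shuffle=True` would give you the batching and shuffling for free, which is the more idiomatic route.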