我使用 LSTM 代码:https://github.com/Khamies/LSTM-Variational-AutoEncoder 当我使用自己的土耳其语数据而不是默认数据时遇到这样的错误。模型有问题?不知道是不是数据集的问题?我该怎么办?
我在 Jupyter Notebook 中运行这个单元格
vocab_size = train_data.vocab_size
model = LSTM_VAE(vocab_size = vocab_size, embed_size = embed_size, `hidden_size = hidden_size, latent_size = latent_size).to(device)`
checkpoint = torch.load("models/LSTM_VAE.pt",map_location=torch.device('cpu')) for i in checkpoint:
print(i)
model.load_state_dict(checkpoint["model"])
默认英文数据集结果
model
optimizer
<All keys matched successfully>
我的土耳其语数据集结果
embed.weight
encoder_lstm.weight_ih_l0
encoder_lstm.weight_hh_l0
encoder_lstm.bias_ih_l0
encoder_lstm.bias_hh_l0
mean.weight
mean.bias
log_variance.weight
log_variance.bias
init_hidden_decoder.weight
init_hidden_decoder.bias
decoder_lstm.weight_ih_l0
decoder_lstm.weight_hh_l0
decoder_lstm.bias_ih_l0
decoder_lstm.bias_hh_l0
output.weight
output.bias
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-5-db2fa1007134> in <module>
5 for i in checkpoint:
6 print(i)
----> 7 model.load_state_dict(checkpoint["model"])
KeyError: 'model'
我的数据集https://drive.google.com/drive/folders/1ZWcmUiCdKOxOxdGq7dmKcBBRNSC-IPI5?usp=share_link
我希望模型生成和代码按照您想要的方式工作。