如何在OpenFold中使用*.ckpt文件作为模型?

问题描述 投票:0回答:1

我已经训练了 OpenFold 模型 https://github.com/aqlaboratory/openfold 并获得了检查点文件 (*ckpt) (Pytorch Lighntning)。

请解释一下,如何使用 *.ckpt 文件进行

run_pretrained_openfold.py
的预测?或者也许我需要先以某种方式将其转换为另一种格式?

python3 run_pretrained_openfold.py \ fasta_dir \ data/pdb_mmcif/mmcif_files/ \ --uniref90_database_path data/uniref90/uniref90.fasta \ --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \ --pdb70_database_path data/pdb70/pdb70 \ --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ --output_dir ./ \ --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ --model_device "cuda:0" \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \ --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \ --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \ --config_preset "model_1_ptm" \ **--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt**

如果我使用这个:

 --openfold_checkpoint_path /checkpoints/14my.ckpt
我收到此错误消息:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for AlphaFold:                                                                     Missing key(s) in state_dict: "aux_heads.tm.linear.weight", "aux_heads.tm.linear.bias".

################################################## #

我通过使用

--config_preset "model_1"
而不是
model_1_ptm
解决了该错误。

但是现在我有pdb,其中氨基酸的坐标计算不正确。我使用“手动优化模式”是因为我正在测试自己的优化器,标准 Adam 一般会给出正常结果。

是闪电手动优化的问题还是什么?

手动优化后的肽段: enter image description here

标准 Adam 和自动优化: enter image description here

pytorch pytorch-lightning
1个回答
0
投票

所以,我找到了一种将 *ckpt 转换为我自己的 *npz 模型的方法。 首先,在 train_openfold.py 中我添加了:

def convert_to_pt(ckpt_path, output_path):

checkpoint = torch.load(ckpt_path)
model_state_dict = checkpoint['state_dict']

adjusted_state_dict = {}
for key in model_state_dict.keys():
    adjusted_state_dict[key.replace('model.', "", 1)] = model_state_dict[key]

torch.save(adjusted_state_dict, output_path)
print(f"Converted checkpoint '{ckpt_path}' to PyTorch state dict '{output_path}'.")

以及之后

if(args.resume_from_ckpt):

convert_to_pt(args.resume_from_ckpt, "/npz/model.pt")
print("Checkpoint converted to pt format and saved to")

然后, 我使用 convert_of_weights_to_jax.py 中生成的 *.pt 来创建 *.npz

© www.soinside.com 2019 - 2024. All rights reserved.