我已经训练了 OpenFold 模型 https://github.com/aqlaboratory/openfold 并获得了检查点文件 (*ckpt) (Pytorch Lighntning)。
请解释一下,如何使用 *.ckpt 文件进行
run_pretrained_openfold.py
的预测?或者也许我需要先以某种方式将其转换为另一种格式?
python3 run_pretrained_openfold.py \ fasta_dir \ data/pdb_mmcif/mmcif_files/ \ --uniref90_database_path data/uniref90/uniref90.fasta \ --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \ --pdb70_database_path data/pdb70/pdb70 \ --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ --output_dir ./ \ --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ --model_device "cuda:0" \ --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \ --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \ --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \ --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \ --config_preset "model_1_ptm" \ **--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt**
如果我使用这个:
--openfold_checkpoint_path /checkpoints/14my.ckpt
我收到此错误消息:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for AlphaFold: Missing key(s) in state_dict: "aux_heads.tm.linear.weight", "aux_heads.tm.linear.bias".
################################################## #
我通过使用
--config_preset "model_1"
而不是 model_1_ptm
解决了该错误。
但是现在我有pdb,其中氨基酸的坐标计算不正确。我使用“手动优化模式”是因为我正在测试自己的优化器,标准 Adam 一般会给出正常结果。
是闪电手动优化的问题还是什么?
所以,我找到了一种将 *ckpt 转换为我自己的 *npz 模型的方法。 首先,在 train_openfold.py 中我添加了:
def convert_to_pt(ckpt_path, output_path):
checkpoint = torch.load(ckpt_path)
model_state_dict = checkpoint['state_dict']
adjusted_state_dict = {}
for key in model_state_dict.keys():
adjusted_state_dict[key.replace('model.', "", 1)] = model_state_dict[key]
torch.save(adjusted_state_dict, output_path)
print(f"Converted checkpoint '{ckpt_path}' to PyTorch state dict '{output_path}'.")
以及之后
if(args.resume_from_ckpt):
convert_to_pt(args.resume_from_ckpt, "/npz/model.pt")
print("Checkpoint converted to pt format and saved to")
然后, 我使用 convert_of_weights_to_jax.py 中生成的 *.pt 来创建 *.npz