我正在尝试在 LLAMA 2 上使用 Huggingface PEFT LORA 微调进行 Pytorch Lightning Fabric 分布式 FSDP 训练,但我的代码最终失败:
`FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
File ".......", line 100, in <module>
model, optimizer = fabric.setup(model, optimizer)
ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
如何找出 pytorch 结构中的哪些张量是 float32 类型?