LoraConfig
对象包含一个 target_modules
数组。在某些示例中,目标模块是 ["query_key_value"]
,有时是 ["q", "v"]
,有时是其他东西。
我不太明白目标模块的值从哪里来。我应该在模型页面的哪里查看 LoRA 适配模块是什么?
一个示例(针对型号 Falcon 7B):
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
另一个示例(针对型号 Opt-6.7B):
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
另一个(适用于型号 Flan-T5-xxl):
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q", "v"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.SEQ_2_SEQ_LM
)
假设您加载了您选择的某个模型:
model = AutoModelForCausalLM.from_pretrained("some-model-checkpoint")
然后你可以通过打印这个模型来看到可用的模块:
print(model)
您将得到类似这样的信息(SalesForce/CodeGen25):
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(51200, 4096, padding_idx=0)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(act_fn): SiLUActivation()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=51200, bias=False)
)
在我的例子中,您可以找到包含 q_proj、k_proj、v_proj 和 o_proj 的 LLamaAttention 模块。这是一些可用于 LoRA 的模块。
我建议您阅读LoRA论文中有关使用哪些模块的更多信息。
这里是获得所有线性的方法。
def find_all_linear_names(model):
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)