在不同模型上应用PEFT / LoRA的目标模块

问题描述 投票:0回答:2

我正在查看一些在不同模型上使用 PEFT 的不同示例

LoraConfig
对象包含一个
target_modules
数组。在某些示例中,目标模块是
["query_key_value"]
,有时是
["q", "v"]
,有时是其他东西。

我不太明白目标模块的值从哪里来。我应该在模型页面的哪里查看 LoRA 适配模块是什么?

一个示例(针对型号 Falcon 7B):

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]

另一个示例(针对型号 Opt-6.7B):

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

另一个(适用于型号 Flan-T5-xxl):

lora_config = LoraConfig(
 r=16,
 lora_alpha=32,
 target_modules=["q", "v"],
 lora_dropout=0.05,
 bias="none",
 task_type=TaskType.SEQ_2_SEQ_LM
)
nlp huggingface-transformers huggingface fine-tuning peft
2个回答
20
投票

假设您加载了您选择的某个模型:

model = AutoModelForCausalLM.from_pretrained("some-model-checkpoint")

然后你可以通过打印这个模型来看到可用的模块:

print(model)

您将得到类似这样的信息(SalesForce/CodeGen25):

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(51200, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=51200, bias=False)
)

在我的例子中,您可以找到包含 q_proj、k_proj、v_proj 和 o_proj 的 LLamaAttention 模块。这是一些可用于 LoRA 的模块。

我建议您阅读LoRA论文中有关使用哪些模块的更多信息。


0
投票

这里是获得所有线性的方法。

def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
© www.soinside.com 2019 - 2024. All rights reserved.