【LLM训练系列02】如何找到一个大模型Lora的target_modules

news2024/11/25 19:12:55

方法1:观察attention中的线性层

import numpy as np
import pandas as pd
from peft import PeftModel
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from typing import List
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
import os
os.environ['CUDA_VISIBLE_DEVICES']='1,2'
os.environ["TOKENIZERS_PARALLELISM"] = "false"


model_path ="/home/jovyan/codes/llms/Qwen2.5-14B-Instruct"
base_model = AutoModel.from_pretrained(model_path, device_map='cuda:0',trust_remote_code=True)



打印attention模型层的名字

for name, module in base_model.named_modules():
    if 'attn' in name or 'attention' in name:  # Common attention module names
        print(name)
        for sub_name, sub_module in module.named_modules():  # Check sub-modules within attention
            print(f"  - {sub_name}")

方法2:通过bitsandbytes量化查找线性层

import bitsandbytes as bnb
def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            names = name.split(".")
            # model-specific
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)

加载模型

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
base_model = AutoModel.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto"
    )

查找Lora的目标层

find_all_linear_names(base_model)


还有个函数,一样的原理

def find_target_modules(model):
    # Initialize a Set to Store Unique Layers
    unique_layers = set()
    
    # Iterate Over All Named Modules in the Model
    for name, module in model.named_modules():
        # Check if the Module Type Contains 'Linear4bit'
        if "Linear4bit" in str(type(module)):
            # Extract the Type of the Layer
            layer_type = name.split('.')[-1]
            
            # Add the Layer Type to the Set of Unique Layers
            unique_layers.add(layer_type)

    # Return the Set of Unique Layers Converted to a List
    return list(unique_layers)

find_target_modules(base_model)

方法3:通过分析开源框架的源码swift

代码地址

from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class ModelKeys:

    model_type: str = None

    module_list: str = None

    embedding: str = None

    mlp: str = None

    down_proj: str = None

    attention: str = None

    o_proj: str = None

    q_proj: str = None

    k_proj: str = None

    v_proj: str = None

    qkv_proj: str = None

    qk_proj: str = None

    qa_proj: str = None

    qb_proj: str = None

    kva_proj: str = None

    kvb_proj: str = None

    output: str = None


@dataclass
class MultiModelKeys(ModelKeys):
    language_model: Union[List[str], str] = field(default_factory=list)
    connector: Union[List[str], str] = field(default_factory=list)
    vision_tower: Union[List[str], str] = field(default_factory=list)
    generator: Union[List[str], str] = field(default_factory=list)

    def __post_init__(self):
        # compat
        for key in ['language_model', 'connector', 'vision_tower', 'generator']:
            v = getattr(self, key)
            if isinstance(v, str):
                setattr(self, key, [v])
            if v is None:
                setattr(self, key, [])


LLAMA_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    q_proj='model.layers.{}.self_attn.q_proj',
    k_proj='model.layers.{}.self_attn.k_proj',
    v_proj='model.layers.{}.self_attn.v_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

INTERNLM2_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.feed_forward',
    down_proj='model.layers.{}.feed_forward.w2',
    attention='model.layers.{}.attention',
    o_proj='model.layers.{}.attention.wo',
    qkv_proj='model.layers.{}.attention.wqkv',
    embedding='model.tok_embeddings',
    output='output',
)

CHATGLM_KEYS = ModelKeys(
    module_list='transformer.encoder.layers',
    mlp='transformer.encoder.layers.{}.mlp',
    down_proj='transformer.encoder.layers.{}.mlp.dense_4h_to_h',
    attention='transformer.encoder.layers.{}.self_attention',
    o_proj='transformer.encoder.layers.{}.self_attention.dense',
    qkv_proj='transformer.encoder.layers.{}.self_attention.query_key_value',
    embedding='transformer.embedding',
    output='transformer.output_layer',
)

BAICHUAN_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    qkv_proj='model.layers.{}.self_attn.W_pack',
    embedding='model.embed_tokens',
    output='lm_head',
)

YUAN_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    qk_proj='model.layers.{}.self_attn.qk_proj',
    o_proj='model.layers.{}.self_attn.o_proj',
    q_proj='model.layers.{}.self_attn.q_proj',
    k_proj='model.layers.{}.self_attn.k_proj',
    v_proj='model.layers.{}.self_attn.v_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

CODEFUSE_KEYS = ModelKeys(
    module_list='gpt_neox.layers',
    mlp='gpt_neox.layers.{}.mlp',
    down_proj='gpt_neox.layers.{}.mlp.dense_4h_to_h',
    attention='gpt_neox.layers.{}.attention',
    o_proj='gpt_neox.layers.{}.attention.dense',
    qkv_proj='gpt_neox.layers.{}.attention.query_key_value',
    embedding='gpt_neox.embed_in',
    output='gpt_neox.embed_out',
)

PHI2_KEYS = ModelKeys(
    module_list='transformer.h',
    mlp='transformer.h.{}.mlp',
    down_proj='transformer.h.{}.mlp.c_proj',
    attention='transformer.h.{}.mixer',
    o_proj='transformer.h.{}.mixer.out_proj',
    qkv_proj='transformer.h.{}.mixer.Wqkv',
    embedding='transformer.embd',
    output='lm_head',
)

QWEN_KEYS = ModelKeys(
    module_list='transformer.h',
    mlp='transformer.h.{}.mlp',
    down_proj='transformer.h.{}.mlp.c_proj',
    attention='transformer.h.{}.attn',
    o_proj='transformer.h.{}.attn.c_proj',
    qkv_proj='transformer.h.{}.attn.c_attn',
    embedding='transformer.wte',
    output='lm_head',
)

PHI3_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    qkv_proj='model.layers.{}.self_attn.qkv_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

PHI3_SMALL_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.dense',
    qkv_proj='model.layers.{}.self_attn.query_key_value',
    embedding='model.embed_tokens',
    output='lm_head',
)

DEEPSEEK_V2_KEYS = ModelKeys(
    module_list='model.layers',
    mlp='model.layers.{}.mlp',
    down_proj='model.layers.{}.mlp.down_proj',
    attention='model.layers.{}.self_attn',
    o_proj='model.layers.{}.self_attn.o_proj',
    qa_proj='model.layers.{}.self_attn.q_a_proj',
    qb_proj='model.layers.{}.self_attn.q_b_proj',
    kva_proj='model.layers.{}.self_attn.kv_a_proj_with_mqa',
    kvb_proj='model.layers.{}.self_attn.kv_b_proj',
    embedding='model.embed_tokens',
    output='lm_head',
)

我的博客即将同步至腾讯云开发者社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=3hiaca88ulogc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2247454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何选择服务器

如何选择服务器 选择服务器时应考虑以下几个关键因素: 性能需求。根据网站的预期流量和负载情况,选择合适的处理器、内存和存储容量。考虑网站是否需要处理大量动态内容或高分辨率媒体文件。 可扩展性。选择一个可以轻松扩展的服务器架构,以便…

C++共享智能指针

C中没有垃圾回收机制,必须自己释放分配的内存,否则就会造成内存泄漏。解决这个问题最有效的方式是使用智能指针。 智能指针是存储指向动态分配(堆)对象指针的类,用于生存期的控制,能够确保在离开指针所在作用域时,自动…

python Flask指定IP和端口

from flask import Flask, request import uuidimport json import osapp Flask(__name__)app.route(/) def hello_world():return Hello, World!if __name__ __main__:app.run(host0.0.0.0, port5000)

虚幻引擎---初识篇

一、学习途径 虚幻引擎官方文档:https://dev.epicgames.com/documentation/zh-cn/unreal-engine/unreal-engine-5-5-documentation虚幻引擎在线学习平台:https://dev.epicgames.com/community/unreal-engine/learning哔哩哔哩:https://www.b…

Java开发经验——SpringRestTemplate常见错误

摘要 本文分析了在使用Spring框架的RestTemplate发送表单请求时遇到的常见错误。主要问题在于将表单参数错误地以JSON格式提交,导致服务器无法正确解析参数。文章提供了错误案例的分析,并提出了修正方法。 1. 表单参数类型是MultiValueMap RestControl…

oracle会话追踪

一 跟踪当前会话 1.1 查看当前会话的SID,SERIAL# #在当前会话里执行 示例: SQL> select distinct userenv(sid) from v$mystat; USERENV(SID) -------------- 1945 SQL> select distinct sid,serial# from v$session where sid1945; SID SERIAL# …

数据可视化复习2-绘制折线图+条形图(叠加条形图,并列条形图,水平条形图)+ 饼状图 + 直方图

目录 目录 一、绘制折线图 1.使用pyplot 2.使用numpy ​编辑 3.使用DataFrame ​编辑 二、绘制条形图(柱状图) 1.简单条形图 2.绘制叠加条形图 3.绘制并列条形图 4.水平条形图 ​编辑 三、绘制饼状图 四、绘制散点图和直方图 1.散点图 2…

postgresql按照年月日统计历史数据

1.按照日 SELECT a.time,COALESCE(b.counts,0) as counts from ( SELECT to_char ( b, YYYY-MM-DD ) AS time FROM generate_series ( to_timestamp ( 2024-06-01, YYYY-MM-DD hh24:mi:ss ), to_timestamp ( 2024-06-30, YYYY-MM-DD hh24:mi:ss ), 1 days ) AS b GROUP BY tim…

【JavaEE初阶 — 多线程】定时器的应用及模拟实现

目录 1. 标准库中的定时器 1.1 Timer 的定义 1.2 Timer 的原理 1.3 Timer 的使用 1.4 Timer 的弊端 1.5 ScheduledExecutorService 2. 模拟实现定时器 2.1 实现定时器的步骤 2.1.1 定义类描述任务 定义类描述任务 第一种定义方法 …

一文学会Golang里拼接字符串的6种方式(性能对比)

g o l a n g golang golang的 s t r i n g string string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去。主要有以下几种拼接方式 拼接方式介绍 1.使用 s t r i n g string string自带的运算符 ans ans s2. 使用…

LeetCode 3244.新增道路查询后的最短距离 II:贪心(跃迁合并)-9行py(O(n))

【LetMeFly】3244.新增道路查询后的最短距离 II:贪心(跃迁合并)-9行py(O(n)) 力扣题目链接:https://leetcode.cn/problems/shortest-distance-after-road-addition-queries-ii/ 给你一个整数 n 和一个二维…

MyBatis中特殊SQL的执行

目录 1.模糊查询 2.批量删除 3.动态设置表名 4.添加功能获取自增的主键 1.模糊查询 List<User> getUserByLike(Param("username") String username); <select id"getUserByLike" resultType"com.atguigu.mybatis.pojo.User">&…

ES 基本使用与二次封装

概述 基本了解 Elasticsearch 是一个开源的分布式搜索和分析引擎&#xff0c;基于 Apache Lucene 构建。它提供了对海量数据的快速全文搜索、结构化搜索和分析功能&#xff0c;是目前流行的大数据处理工具之一。主要特点即高效搜索、分布式存储、拓展性强 核心功能 全文搜索:…

Azkaban部署

首先我们需要现在相关的组件&#xff0c;在这里已经给大家准备好了相关的安装包&#xff0c;有需要的可以自行下载。 只需要启动hadoop集群就可以&#xff0c;如果现在你的hive是打开的&#xff0c;那么请你关闭&#xff01;&#xff01;&#xff01; 如果不关会造成证书冲突…

Jmeter中的定时器

4&#xff09;定时器 1--固定定时器 功能特点 固定延迟&#xff1a;在每个请求之间添加固定的延迟时间。精确控制&#xff1a;可以精确控制请求的发送频率。简单易用&#xff1a;配置简单&#xff0c;易于理解和使用。 配置步骤 添加固定定时器 右键点击需要添加定时器的请求…

JavaEE初学07

JavaEE初学07 MybatisORMMybatis一对一结果映射一对多结果映射 Mybatis动态sqlif标签trim标签where标签set标签foreach标签补充 Mybatis Mybatis是一款优秀的持久层框架&#xff0c;他支持自定义SQL、存储过程以及高级映射。Mybatis几乎免除了所有的JDBC代码以及设置参数和获取…

【layui】table的switch、edit修改

<title>简单表格数据</title><div class"layui-card layadmin-header"><div class"layui-breadcrumb" lay-filter"breadcrumb"><a>系统设置</a><a>简单表格数据</a></div> </div>&…

工具使用_docker容器_crossbuild

1. 工具简介 2. 工具使用 拉取 multiarch/crossbuild 镜像&#xff1a; docker pull multiarch/crossbuild 创建工作目录和示例代码&#xff1a; mkdir -p ~/crossbuild-test cd ~/crossbuild-test 创建 helloworld.c &#xff1a; #include <stdio.h>int main() …

Android 天气APP(三十七)新版AS编译、更新镜像源、仓库源、修复部分BUG

上一篇&#xff1a;Android 天气APP&#xff08;三十六&#xff09;运行到本地AS、更新项目版本依赖、去掉ButterKnife 新版AS编译、更新镜像源、仓库源、修复部分BUG 前言正文一、更新镜像源① 腾讯源③ 阿里源 二、更新仓库源三、修复城市重名BUG四、地图加载问题五、源码 前…

基于Java Springboot海洋馆预约系统

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 数据…