langchain-ChatGLM源码阅读:模型加载

news2025/1/23 7:19:24

文章目录

    • 使用命令行参数初始化加载器
    • 模型实例化
    • 清空显存
    • 加载模型调用链
        • `loader.py`的`_load_model`方法
        • `auto_factory.py`的`from_pretrained`方法
        • `modeling_utils.py`的`from_pretrained`方法
        • `hub.py`的`get_checkpoint_shard_files`方法
        • `modeling_utils.py`的`_load_pretrained_mode`方法
        • 回到`loader.py`的`_load_model`方法

使用命令行参数初始化加载器

loader.py

def __init__(self, params: dict = None):
        """
        模型初始化
        :param params:
        """
        self.model = None
        self.tokenizer = None
        self.params = params or {}
        self.model_name = params.get('model_name', False)
        self.model_path = params.get('model_path', None)
        self.no_remote_model = params.get('no_remote_model', False)
        self.lora = params.get('lora', '')
        self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
        self.lora_dir = params.get('lora_dir', '')
        self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
        self.load_in_8bit = params.get('load_in_8bit', False)
        self.bf16 = params.get('bf16', False)

        self.is_chatgmlcpp = "chatglm2-cpp" == self.model_name

模型实例化

shared.py

def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> Any:
    """
    init llm_model_ins LLM
    :param llm_model: model_name
    :param no_remote_model:  remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model
    :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder
    :return:
    """
    # 默认为chatglm2-6b-32k
    pre_model_name = loaderCheckPoint.model_name
    # model_config中chatglm2-6b-32k对应参数
    llm_model_info = llm_model_dict[pre_model_name]

    if no_remote_model:
        loaderCheckPoint.no_remote_model = no_remote_model
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

    # 如果指定了参数,则使用参数的配置,默认为none
    if llm_model:
        llm_model_info = llm_model_dict[llm_model]

    loaderCheckPoint.model_name = llm_model_info['name']
    # 默认为THUDM/chatglm2-6b-32k
    loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
    # 需手动指定路径
    loaderCheckPoint.model_path = llm_model_info["local_model_path"]
    # ChatGLMLLMChain
    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
        loaderCheckPoint.unload_model()
    else:
        loaderCheckPoint.reload_model()
    # 根据名称自动加载类:<class 'models.chatglm_llm.ChatGLMLLMChain'>
    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
    # 将类实例化为模型对象
    modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
    if 'FastChatOpenAILLM' in llm_model_info["provides"]:
        modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
        modelInsLLM.call_model_name(llm_model_info['name'])
        modelInsLLM.set_api_key(llm_model_info['api_key'])
    return modelInsLLM

loader.py

    def reload_model(self):
        self.unload_model()
        self.model_config = self._load_model_config()

        if self.use_ptuning_v2:
            try:
                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                print(e)
                print("加载PrefixEncoder config.json失败")

        self.model, self.tokenizer = self._load_model()

        if self.lora:
            self._add_lora_to_model([self.lora])

        if self.use_ptuning_v2:
            try:
                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
                print("加载ptuning检查点成功!")
            except Exception as e:
                print(e)
                print("加载PrefixEncoder模型参数失败")
        # llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法
        if not self.is_llamacpp and not self.is_chatgmlcpp:
            self.model = self.model.eval()

清空显存

在加载模型前先清空显存
loader.py

    def unload_model(self):
        del self.model
        del self.tokenizer
        self.model = self.tokenizer = None
        self.clear_torch_cache()
        
    def clear_torch_cache(self):
        # 垃圾回收, 避免内存泄漏和优化内存使用
        gc.collect()
        if self.llm_device.lower() != "cpu":
            # 检测系统是否支持MPS,这是是Apple在Mac设备上用于GPU加速的框架
            if torch.has_mps:
                try:
                    from torch.mps import empty_cache
                    empty_cache()
                except Exception as e:
                    print(e)
                    print(
                        "如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
            elif torch.has_cuda:
                device_id = "0" if torch.cuda.is_available() and (":" not in self.llm_device) else None
                CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
                with torch.cuda.device(CUDA_DEVICE):
                    # 释放GPU显存缓存中的任何未使用的内存。
                    # PyTorch在GPU上申请和释放内存时,部分内存会保留在缓存中重复利用,
                    # empty_cache()可以释放这些缓存memory。
                    torch.cuda.empty_cache()
                    # 用于CUDA IPC内存共享的垃圾回收。
                    # 在多进程GPU训练中,进程间会共享部分内存,
                    # ipc_collect()可以显式收集共享内存垃圾。
                    torch.cuda.ipc_collect()
            else:
                print("未检测到 cuda 或 mps,暂不支持清理显存")

加载模型调用链

loader.py_load_model方法

model = LoaderClass.from_pretrained(checkpoint,
                                                        config=self.model_config,
                                                        torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                        trust_remote_code=True).half()

auto_factory.pyfrom_pretrained方法

包路径:site-packages/transformers/models/auto/auto_factory.py
作用:将配置对象的类与模型类或对象建立关联,以便根据配置来获取相应的模型类或对象。这通常用于管理不同配置下的模型选择和实例化。例如,根据不同的配置选择不同的模型架构或模型参数。

cls.register(config.__class__, model_class, exist_ok=True)

modeling_utils.pyfrom_pretrained方法

包路径:site-packages/transformers/modeling_utils.py
作用:因为没有显式指定模型路径,所以只能通过缓存方式下载和加载。

                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                    # result when internet is up, the repo and revision exist, but the file does not.
                    if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path,
                            _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                            **cached_file_kwargs,
                        )
			
			...
			
        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        if is_sharded:
            # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=commit_hash,
            )

hub.pyget_checkpoint_shard_files方法

包路径:site-packages/transformers/utils/hub.py
作用:第一次启动项目时下载模型到本地缓存。

    for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
        try:
            # Load from URL
            cached_filename = cached_file(
                pretrained_model_name_or_path,
                shard_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=_commit_hash,
            )

modeling_utils.py_load_pretrained_mode方法

包路径:site-packages/transformers/modeling_utils.py
作用:遍历权重文件分片,逐一加载这些分片,但会跳过那些只包含磁盘上载权重的分片文件,显示加载的进度条,也就是下面这个东西,但此时模型权重还没有加载到显存中

在这里插入图片描述

            if len(resolved_archive_file) > 1:
                resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
            for shard_file in resolved_archive_file:
                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                if shard_file in disk_only_shard_files:
                    continue
                state_dict = load_state_dict(shard_file)

回到loader.py_load_model方法

这里主要是为了把模型加载到显存,可以使用多卡加载方式

                       else:
                            # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
                            # 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
                            from accelerate.utils import get_balanced_memory
                            max_memory = get_balanced_memory(model,
                                                             dtype=torch.int8 if self.load_in_8bit else None,
                                                             low_zero=False,
                                                             no_split_module_classes=model._no_split_modules)
                            self.device_map = infer_auto_device_map(model,
                                                                    dtype=torch.float16 if not self.load_in_8bit else torch.int8,
                                                                    max_memory=max_memory,
                                                                    no_split_module_classes=model._no_split_modules)

                    model = dispatch_model(model, device_map=self.device_map)
  • 未执行上述代码之前,显存占用为0
    在这里插入图片描述

  • 执行max_memory = get_balanced_memory(…):在这一部分代码中,通过调用 get_balanced_memory 函数来获取一个适当的内存分配方案,执行完后每个卡都会产生少量的显存占用
    在这里插入图片描述

在这里插入图片描述

  • 执行self.device_map = infer_auto_device_map(…):根据模型、数据类型、内存分配等信息来推断设备映射,将模型的不同部分分配到不同的设备上进行计算。
    在这里插入图片描述
  • 执行model = dispatch_model(model, device_map=self.device_map):根据生成的设备映射 将模型的不同部分分配到不同的设备上进行计算。这样,模型就可以利用多个GPU并行计算,以提高计算性能,模型权重被全部加载到显存。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/899251.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电脑远程接入软件可以进行文件传输吗?快解析内网穿透

电脑远程接入软件的出现&#xff0c;让我们可以在两台电脑之间进行交互和操作。但是&#xff0c;很多人对于这些软件能否进行文件传输还存在一些疑问。下面的文章将解答这个问题。 1.电脑远程接入软件可以进行文件传输。传统上&#xff0c;我们可能会通过传输线或者移动存储设…

听GPT 讲Prometheus源代码--promql/promdb

Prometheus的promql目录包含PromQL(Prometheus Query Language)的解析和执行代码: parser.go 定义PromQL语法结构和parser,用于将PromQL查询语句进行语法解析。 semantic.go 实现PromQL的语义分析,检查查询是否语法正确且语义合理。 engine.go 定义PromQL执行引擎的接口和数据结…

Chapter 15: Object-Oriented Programming | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介Object-oriented programmingManaging larger programsGetting startedUsing objectsStarting with programsSubdividing a problemOur first Python objectClasses as typesObject lifecycleMultiple instancesInheritanceSummaryGlossa…

AlexNet阅读笔记

ImageNet classification with deep convolutional neural networks 原文链接&#xff1a;https://dl.acm.org/doi/abs/10.1145/3065386 中文翻译&#xff1a;https://blog.csdn.net/qq_38473254/article/details/132307508 使用深度卷积神经网络进行 ImageNet 分类 摘要 大…

all in one之安装群辉和加硬盘(第五章)

安装群辉和加硬盘 安装群辉 群辉系统选择和介绍 PVE7.1虚拟机安装黑群晖教程 安装教程2 下载对应的群辉系统&#xff0c;根据自己的电脑性能来选择并下载&#xff1a; 百度网盘链接:提取码: 4itr 下载链接【私人博客】 第一步&#xff0c;把下载的img镜像上传到pve 记住上…

idea2023 springboot2.7.5+mybatis+jsp 初学单表增删改查

创建项目 因为2.7.14使用量较少&#xff0c;特更改spring-boot为2.7.5版本 配置端口号 打开Sm01Application类&#xff0c;右键运行启动项目&#xff0c;或者按照如下箭头启动 启动后&#xff0c;控制台提示如下信息表示成功 此刻在浏览器中输入&#xff1a;http://lo…

UltraEdit2023代码编辑器下载安装教程

UltraEdit是深受编程人士喜爱的代码编辑器之一&#xff0c;简洁干净的工作界面&#xff0c;标配的语法高亮功能&#xff0c;代码折叠等高效编程功能&#xff0c;并且&#xff0c;还支持HTML、PHP和JavaScript等语法&#xff0c;让代码编辑、文档内容处理更加方便。 作为一款广…

文心一言 VS 讯飞星火 VS chatgpt (78)-- 算法导论7.4 2题

二、如果用go语言&#xff0c;证明&#xff1a;在最好情况下&#xff0c;快速排序的运行时间为 Ω(nlgn)。 文心一言&#xff1a; 在最好情况下&#xff0c;快速排序的运行时间为Ω(nlgn)的证明如下&#xff1a; 快速排序是一种分治算法&#xff0c;它将一个数组分成两个子数…

Android webrtc实战(一)录制本地视频并播放,附带详细的基础知识讲解

目录 一、创建PeerConnectionFactory 初始化 构建对象 二、创建AudioDeviceModule AudioDeviceModule JavaAudioDeviceModule 构建对象 setAudioAttributes setAudioFormat setAudioSource 创建录制视频相关对象 创建VideoSource 创建VideoCapturer 创建VideoTra…

一台电脑访问另一台电脑的虚拟机

打开虚拟机的ip&#xff1a;端口映射 虚拟网络编辑器设置&#xff1a;端口转发 访问虚拟机的主机IP 转发端口 ssh root另一台电脑的虚拟机主机ip: 9000注意&#xff1a;不是虚拟机的ip

通讯录实现【C语言】

目录 前言 一、整体逻辑分析 二、实现步骤 1、创建菜单和多次操作问题 2、创建通讯录 3、初始化通讯录 4、添加联系人 5、显示联系人 6、删除指定联系人 ​7、查找指定联系人 8、修改联系人信息 9、排序联系人信息 三、全部源码 前言 我们上期已经详细的介绍了自定…

docker学习(十五)docker安装MongoDB

什么是MongoDB? MongoDB 是一个开源的、面向文档的 NoSQL 数据库管理系统&#xff0c;它以高性能、灵活的数据存储方式而闻名。与传统的关系型数据库不同&#xff0c;MongoDB 采用了一种称为 BSON&#xff08;Binary JSON&#xff09;的二进制 JSON 格式来存储数据。它是一种非…

【AIGC 讯飞星火 | 百度AI|ChatGPT| 】智能对比

AI智能对比 &#x1f378; 前言&#x1f37a; 概念类对比&#x1f375; 讯飞&#x1f375; 百度AI&#x1f375; chatGPT &#x1f379; 功能类对比☕ 讯飞☕ 百度AI☕ chatGPT &#x1f943; 可输入字数对比&#x1f964; 百度AI&#x1f964; 讯飞&#x1f964; chatGPT &…

markdown编写微信公众号文章

微信公众号文章编写&#xff0c;暂不支持MarkDown的使用&#xff0c; 推荐工具&#xff1a; 墨滴 全称叫做&#xff1a; Makedown Nice&#xff0c;后面会以mdNice代替使用。 通过官网的写文章&#xff0c;支持在线编译安装chrome浏览器插件&#xff0c; 支持在微信公众号编译…

字符串旋转(1)

目录 ​编辑 题目要求&#x1f60d;&#xff1a; 题目内容❤&#xff1a; 题目分析&#x1f4da;&#xff1a; 主函数部分&#x1f4d5;&#xff1a;​编辑 方法一&#x1f412;&#xff1a; 方法二&#x1f412;&#x1f412;&#xff1a; 方法三&#x1f412;&#x1f…

Day978.如何在移动App中使用OAuth 2.0? -OAuth 2.0

如何在移动App中使用OAuth 2.0&#xff1f; Hi&#xff0c;我是阿昌&#xff0c;今天学习记录的是关于如何在移动App中使用OAuth 2.0&#xff1f;的内容。 除了 Web 应用外&#xff0c;现实环境中还有非常多的移动 App。 那么&#xff0c;在移动 App 中&#xff0c;能不能使…

手把手带你设计接口自动化测试用例(一):提取接口信息并分析

1、测试行业市场现状 随着市场需求的变化&#xff0c;大部分企业在招聘测试人员时&#xff0c;都会提出接口自动化测试的相关要求&#xff0c;为什么会这样呢&#xff1f; 目前&#xff0c;软件构架基本上都是前后端分离的&#xff0c;软件的主要功能由服务端提供。从整个软件…

生成国密SM2密钥对

在线生成国密密钥对 生成的密钥对要妥善保管&#xff0c;丢失是无法找回的。

windows无法与设备或主DNS服务器通信

今天电脑连上wifi后发现qq可以登录,爱奇艺也可以正常使用,但是就浏览器不能用,不管哪个网站都是无法访问,点击下面的Windows网络诊断后发现是因为windows无法与设备或主DNS服务器通信 1.右下角右键wifi图标,打开网络和internet设置 2.点击网络和共享中心 3. 点击更改适配器设置…

(杭电多校)2023“钉耙编程”中国大学生算法设计超级联赛(9)

1002 shortest path 记忆化搜索可以用 map 实现&#xff0c;频繁读取而不考虑元素顺序的可以使用 unordered_map &#xff0c;有效降低时间空间复杂度 dfs(n/2)n%21,其中n%2表示将n变为偶数的次数,1表示操作n/2,dfs(n/2)即表示将n/2变为1的次数 AC代码: #include<iostre…