chatglm2-6b在P40上做LORA微调 | 京东云技术团队

news2025/4/18 2:18:50

背景:

目前,大模型的技术应用已经遍地开花。最快的应用方式无非是利用自有垂直领域的数据进行模型微调。chatglm2-6b在国内开源的大模型上,效果比较突出。本文章分享的内容是用chatglm2-6b模型在集团EA的P40机器上进行垂直领域的LORA微调。

一、chatglm2-6b介绍

github: https://github.com/THUDM/ChatGLM2-6B

chatglm2-6b相比于chatglm有几方面的提升:

1. 性能提升: 相比初代模型,升级了 ChatGLM2-6B 的基座模型,同时在各项数据集评测上取得了不错的成绩;

2. 更长的上下文: 我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练;

3. 更高效的推理: 基于 Multi-Query Attention 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%;

4. 更开放的协议:ChatGLM2-6B 权重对学术研究完全开放,在填写问卷进行登记后亦允许免费商业使用。

二、微调环境介绍

2.1 性能要求

推理这块,chatglm2-6b在精度是fp16上只需要14G的显存,所以P40是可以cover的。

EA上P40显卡的配置如下:

2.2 镜像环境

做微调之前,需要编译环境进行配置,我这块用的是docker镜像的方式来加载镜像环境,具体配置如下:

FROM base-clone-mamba-py37-cuda11.0-gpu

# mpich
RUN yum install mpich  

# create my own environment
RUN conda create -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ --override --yes --name py39 python=3.9
# display my own environment in Launcher
RUN source activate py39 \
    && conda install --yes --quiet ipykernel \
    && python -m ipykernel install --name py39 --display-name "py39"

# install your own requirement package
RUN source activate py39 \
    && conda install -y -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \
    pytorch  torchvision torchaudio faiss-gpu \
    && pip install --no-cache-dir  --ignore-installed -i https://pypi.tuna.tsinghua.edu.cn/simple \
    protobuf \
    streamlit \
    transformers==4.29.1 \
    cpm_kernels \
    mdtex2html \
    gradio==3.28.3 \
	sentencepiece \
	accelerate \
	langchain \
    pymupdf \
	unstructured[local-inference] \
	layoutparser[layoutmodels,tesseract] \
	nltk~=3.8.1 \
	sentence-transformers \
	beautifulsoup4 \
	icetk \
	fastapi~=0.95.0 \
	uvicorn~=0.21.1 \
	pypinyin~=0.48.0 \
    click~=8.1.3 \
    tabulate \
    feedparser \
    azure-core \
    openai \
    pydantic~=1.10.7 \
    starlette~=0.26.1 \
    numpy~=1.23.5 \
    tqdm~=4.65.0 \
    requests~=2.28.2 \
    rouge_chinese \
    jieba \
    datasets \
    deepspeed \
	pdf2image \
	urllib3==1.26.15 \
    tenacity~=8.2.2 \
    autopep8 \
    paddleocr \
    mpi4py \
    tiktoken

如果需要使用deepspeed方式来训练, EA上缺少mpich信息传递工具包,需要自己手动安装。

2.3 模型下载

huggingface地址: https://huggingface.co/THUDM/chatglm2-6b/tree/main

三、LORA微调

3.1 LORA介绍

paper: https://arxiv.org/pdf/2106.09685.pdf

LORA(Low-Rank Adaptation of Large Language Models)微调方法: 冻结预训练好的模型权重参数,在冻结原模型参数的情况下,通过往模型中加入额外的网络层,并只训练这些新增的网络层参数。

LoRA 的思想:

  • 在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作。
  • 训练的时候固定 PLM 的参数,只训练降维矩阵A与升维矩B。而模型的输入输出维度不变,输出时将BA与 PLM 的参数叠加。
  • 用随机高斯分布初始化A,用 0 矩阵初始化B,保证训练的开始此旁路矩阵依然是 0 矩阵。

3.2 微调

huggingface提供的peft工具可以方便微调PLM模型,这里也是采用的peft工具来创建LORA。

peft的github: https://gitcode.net/mirrors/huggingface/peft?utm_source=csdn_github_accelerator

加载模型和lora微调:

    # load model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModel.from_pretrained(args.model_dir, trust_remote_code=True)
    
    print("tokenizer:", tokenizer)
    
    # get LoRA model
    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",)
    
    # 加载lora模型
    model = get_peft_model(model, config)
    # 半精度方式
    model = model.half().to(device)

这里需要注意的是,用huggingface加载本地模型,需要创建work文件,EA上没有权限在没有在.cache创建,这里需要自己先制定work路径。

import os
os.environ['TRANSFORMERS_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/"
os.environ['HF_MODULES_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/"



如果需要用deepspeed方式训练,选择你需要的zero-stage方式:

    conf = {"train_micro_batch_size_per_gpu": args.train_batch_size,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-5,
                    "betas": [
                        0.9,
                        0.95
                    ],
                    "eps": 1e-8,
                    "weight_decay": 5e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 1,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True
                },
                "allgather_partitions": True,
                "allgather_bucket_size": 2e8,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 2e8,
                "contiguous_gradients": True
            },
            "steps_per_print": args.log_steps
            }

其他都是数据处理处理方面的工作,需要关注的就是怎么去构建prompt,个人认为在领域内做微调构建prompt非常重要,最终对模型的影响也比较大。

四、微调结果

目前模型还在finetune中,batch=1,epoch=3,已经迭代一轮。

作者:京东零售 郑少强

来源:京东云开发者社区 转载请注明来源

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/982209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智慧铁路:机车整备场数字孪生

机车整备场是铁路运输系统中的重要组成部分,它承担着机车的维修、保养和整备工作,对保障铁路运输的运维和安全起着至关重要的作用。 随着铁路运输的发展、机车技术的不断进步,以及数字化转型的不断推进,数字孪生技术在机车整备场…

LeetCode刷题笔记【27】:贪心算法专题-5(无重叠区间、划分字母区间、合并区间)

文章目录 前置知识435. 无重叠区间题目描述参考<452. 用最少数量的箭引爆气球>, 间接求解直接求"重叠区间数量" 763.划分字母区间题目描述贪心 - 建立"最后一个当前字母"数组优化marker创建的过程 56. 合并区间题目描述解题思路代码① 如果有重合就合…

【业务功能篇99】微服务-springcloud-springboot-电商订单模块-生成订单服务-锁定库存

八、生成订单 一个是需要生成订单信息一个是需要生成订单项信息。具体的核心代码为 /*** 创建订单的方法* param vo* return*/private OrderCreateTO createOrder(OrderSubmitVO vo) {OrderCreateTO createTO new OrderCreateTO();// 创建订单OrderEntity orderEntity build…

echarts饼图label自定义样式

生成的options {"tooltip": {"trigger": "item","axisPointer": {"type": "shadow"},"backgroundColor": "rgba(9, 24, 48, 0.5)","borderColor": "rgba(255,255,255,0.4)&q…

Python Qt学习(十)一个简易的POP3邮件客户端

公司把126这类的邮箱网站都封了&#xff0c;正好现在无事&#xff0c;加之&#xff0c;算是一个对这俩周学习Qt的一个总结吧。遂写了这么一个简易的通过POP3协议接收126邮件的客户端。 源代码&#xff1a; # -*- coding: utf-8 -*-# Form implementation generated from read…

OpenCV图像处理——矩形(Rect)类的常用操作

1.Rect类 创建类 Rect类成员变量x、y、width、height&#xff0c;分别为左上角点的坐标和矩形的宽和高。 创建一个Rect对象Rect,并在图像上画该矩形框。 cv::Rect rect(100, 50, 500, 500);cv::Mat cv_src cv::imread("11.JPG");cv::rectangle(cv_src, rect, cv:…

每一座屎山代码背后,都藏着一堆熟读代码规范的研发

&#x1f449;导读 韩寒在《他的国》中写道&#xff1a;“我们懂很多道理&#xff0c;却依然过不好这一生”&#xff0c;人们虽然知道很多道理&#xff0c;但并不一定能将这些道理应用到实际生活中。这种现象在生活中很常见&#xff0c;我们听了很多的成功学的道理&#xff0c;…

接入 NVIDIA A100、吞吐量提高 10 倍!Milvus GPU 版本使用指南

Milvus 2.3 正式支持 NVIDIA A100&#xff01; 作为为数不多的支持 GPU 的向量数据库产品&#xff0c;Milvus 2.3 在吞吐量和低延迟方面都带来了显著的变化&#xff0c;尤其是与此前的 CPU 版本相比&#xff0c;不仅吞吐量提高了 10 倍&#xff0c;还能将延迟控制在极低的水准。…

ChatGLM2-6B 部署

引言 这是ChatGLM2-6B 部署的阅读笔记&#xff0c;主要介绍了ChatGLM2-6B模型的部署和一些原理的简单解释。 ChatGLM-6B 它是单卡开源的对话模型。 充分的中英双语预训练 较低的部署门槛 FP16半精度下&#xff0c;需要至少13G的显存进行推理&#xff0c;甚至可以进一步降低…

罕见病 对称性脂肪瘤(MSL) 马德龙病

如果你体内脂肪瘤分布大致如下 而且个数不断增多 这篇文章适合你 症状 脂肪瘤个数一直增加 而且很对称 比如: 左手臂一个 右手臂一个 别名 多发性对称性脂肪增多症 Multiple symmetric lipomatosis (MSL) 多发性对称性脂肪瘤&#xff08;MSL&#xff09; 脂肪瘤 马德龙病(…

大场景的倾斜摄影三维模型OBJ格式轻量化处理处理关键处理技术分析

大场景的倾斜摄影三维模型OBJ格式轻量化处理处理关键处理技术分析 大场景的倾斜摄影三维模型是指通过航空或地面摄影获取的大范围、高分辨率的地理环境数据。为了在虚拟环境中加载和渲染这些模型&#xff0c;需要对其进行OBJ格式的轻量化处理。本文将分析大场景的倾斜摄影三维模…

SSRF漏洞实战

文章目录 SSRF概述SSRF原理SSRF 危害PHP复现SSRF漏洞检测端口扫描内网Web应用指纹识别攻击内网应用读取本地文件 Weblogic SSRF--Getshell复现SSRF攻击Redis原理漏洞检测端口扫描复现翻车&#xff0c;请看官方复现教程注入HTTP头&#xff0c;利用Redis反弹shell SSRF防御过滤输…

软路由的负载均衡设置:优化网络性能和带宽利用率

在现代网络环境中&#xff0c;提升网络性能和最大化带宽利用率至关重要。通过合理配置软路由IP的负载均衡设置&#xff0c;可以有效地实现这一目标&#xff0c;并提高整体稳定性与效果。本文将详细介绍如何进行软路由IP的负载均衡设置&#xff0c;从而优化网络表现、增加带宽利…

软件架构设计(六) 软件架构风格-MDA(模型驱动架构)

概念 模型驱动架构MDA, 全称叫做Model Driven Architecture。 Model:表示客观事物的抽象表示Architecture:表示构成系统的部件,连接件及其约束的规约Model Driven: 使用模型完成软件的分析,设计,构建,部署和维护等 开发活动MDA起源于分离系统规约和平台实现的思想。之前…

Python入门学习13(面向对象)

一、类的定义和使用 类的使用语法&#xff1a; 创建类对象的语法&#xff1a; ​​​​​​​ class Student:name None #学生的名字age None #学生的年龄def say_hi(self):print(f"Hi大家好&#xff0c;我是{self.name}")stu Student() stu.name &q…

软件系统平台验收测试报告

验收测试 一、验收测试 软件项目验收测试依据招投标文件以及相关行业标准、国家标准、法律法规等对软件的功能性、易用性、可靠性、兼容性、维护性、可移植性和用户文档等进行检测&#xff0c;对软件项目的质量进行科学的评价&#xff0c;为项目验收提供依据。 1、服务内容 …

成功解决OSError: [WinError 1455] 页面文件太小,无法完成操作

最近写了个训练文件&#xff0c;昨天在运行的时候都是好好的&#xff0c;今天一运行就报错了&#xff0c;不得不说&#xff0c;有点点奇怪。 OSError: [WinError 1455] 页面文件太小&#xff0c;无法完成操作。 Error loading "D:\AI\Anaconda\anaconda3\envs\torch1.8\li…

嵌入式学习笔记(16)反汇编工具objdump

2.4.1反汇编的原理&为什么要用反汇编 arm-linux-objdump -D led.elf > led_elf.dis objdump是gcc工具链中的反汇编工具&#xff0c;作用是由编译链接好的elf格式的可执行程序反过来得到汇编源代码 -D表示反汇编 > 左边的是elf可执行程序&#xff08;反汇编的源&am…

Linux RPM JDK升级

以JDK1.8升级JDK17为例 上传jdk17安装包到linux服务器 检查jdk版本 rpm -qa|grep jdk 删除查询到的jdk rpm -e --nodeps jdk1.8-1.8.0_201-fcs.x86_64 删除完毕后安装新的jdk rpm -ivh jdk-17_linux-x64_bin.rpm 检查jdk版本 java -version

Matlab信号处理3:fft(快速傅里叶变换)标准使用方式

Fs 1000; % 采样频率 T 1/Fs; % 采样周期&#xff1a;0.001s L 1500; % 信号长度 t (0:L-1)*T; % 时间向量. 时间向量从0开始递增&#xff0c;0s~1.499sS 0.7*sin(2*pi*50*t) sin(2*pi*120*t); % 模拟原信号 X S 2*randn(size(t)); …