【深度学习】微调ChatGlm3-6b

news2024/9/29 13:27:05

1.前言

        指令微调ChatGlm3-6b。微调教程在github地址中给出,微调环境是Qwen提供的docker镜像为环境。

        镜像获取方式:docker pull qwenllm/qwen:cu117

        github地址:https://github.com/liucongg/ChatGLM-Finetuning

2.微调过程

        github地址中的教程给出了详细的微调过程。如果使用预先准备的docker,微调则更为方便。在实践时候,用了Qwen提供的docker。在使用容器微调时,不需要让容器中开启服务,所以需要以官方提供的镜像为基础,再做一个镜像。 本次实践中使用的镜像与微调Qwen-1.8B的镜像一致。做镜像的具体步骤与命令均在【微调Qwen1.8B】教程中给出。

2-1 环境准备

        开发机器中现有库的版本与requestment.txt中指定的版本不一致,所以使用docker镜像作为开发环境,docker的环境需要安装指定版本的deepspeed和tensorboard。

#运行镜像生成容器 挂载模型
docker run --gpus all -v /ssd/dongzhenheng/LLM/chatglm3-6b:/data/shared/Qwen/chatglm3-6b -it qwenllm/qwen:cu117 bash 
#下载
pip install deepspeed==0.11.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple

 2-2 数据准备

        原github项目中给出了微调数据的格式,每条都包含instruction、input、output的json格式数据。

{"instruction": "你现在是一个信息抽取模型,请你帮我抽取出关系内容为\"性能故障\", \"部件故障\", \"组成\"和 \"检测工具\"的相关三元组,三元组内部用\"_\"连接,三元组之间用\\n分割。文本:", "input": "故障现象:发动机水温高,风扇始终是低速转动,高速档不工作,开空调尤其如此。", "output": "发动机_部件故障_水温高\n风扇_部件故障_低速转动"}
{"instruction": "你现在是一个信息抽取模型,请你帮我抽取出关系内容为\"性能故障\", \"部件故障\", \"组成\"和 \"检测工具\"的相关三元组,三元组内部用\"_\"连接,三元组之间用\\n分割。文本:", "input": "957号汽车故障报告故障现象一辆2007年长丰猎豹飞腾6400越野车,行驶里程27000km。维修室内灯不亮,更换室内灯泡后,发现四个转向灯常亮,后雨刮和后喷水常工作,开关不起作用。", "output": "开关_部件故障_不起作用\n维修室内灯_部件故障_不亮"}
#存放路径
sava_path = '/data/zhenhengdong/WORk/Fine-tuning/ChatGlm3-6B/Datasets'
def write_method(item):
    #建立data.jsonl文件,以追加的方式写入数据
    with jsonlines.open(sava_path + 'data.jsonl', mode = 'a') as json_writer:
        json_writer.write(item)
#读文件
data = pd.read_csv(data_path)
for index in range(len(data)):
    temp_data = data.iloc[index]
    temp_dict = {
        "instruction":prompt.replace('\n',''),
        "input": data.iloc[index]['输入query'],
        "output":data.iloc[index]['输出结果']
        }
    write_method(temp_dict)

2-3 代码准备

        将gitclone下来的ChatGLM-Finetuning-master项目、准备的数据复制到docker 容器中。

docker cp /data/zhenhengdong/WORk/Fine-tuning/ChatGlm3-6B/Codes/ChatGLM-Finetuning-master a27aaa4f78dc:/data/shared/Qwen/

        github教程中给出了多种训练方式,Freeze方法、PT方法、Lora方法、全参方法。在微调时,采用了 Lora方法。

        准备run_train.sh,在微调时直接运行即可。训练代码均采用DeepSpeed进行训练,可设置参数包含train_path、model_name_or_path、mode、train_type、lora_dim、lora_alpha、lora_dropout、lora_module_name、ds_file、num_train_epochs、per_device_train_batch_size、gradient_accumulation_steps、output_dir等, 可根据自己的任务配置。

        Datasetsdata.json是按照2-2中数据格式准备的微调数据。

        通过CUDA_VISIBLE_DEVICES控制具体哪几块卡进行训练,如果不加该参数,表示使用运行机器上所有卡进行训练。CUDA_VISIBLE_DEVICES=0表示使用0号GPU,也可设置为

CUDA_VISIBLE_DEVICES=0,1,2,3
CUDA_VISIBLE_DEVICES=0 deepspeed --master_port 5200 train.py \
                --train_path Datasetsdata.json \
                --model_name_or_path /data/shared/Qwen/chatglm3-6b \
                --per_device_train_batch_size 1 \
                --max_len 1560 \
                --max_src_len 1024 \
                --learning_rate 1e-4 \
                --weight_decay 0.1 \
                --num_train_epochs 2 \
                --gradient_accumulation_steps 4 \
                --warmup_ratio 0.1 \
                --mode glm3 \
                --lora_dim 16 \
                --lora_alpha 64 \
                --lora_dropout 0.1 \
                --lora_module_name "query_key_value,dense_h_to_4h,dense_4h_to_h,dense" \
                --seed 1234 \
                --ds_file ds_zero2_no_offload.json \
                --gradient_checkpointing \
                --show_loss_step 10 \
                --output_dir output-glm3

        训练过程如下: 

 2-4 微调输出

        微调之后会有一个output-glm3的文件夹。

        output-glm3文件夹中有每一轮训练保存的模型

2-5 merge

        在github教程中作者提供了merge.py文件,可以使用merge.py文件进行合并。在合并时,也可自己写merge代码。

        将微调的模型从docker容器中cp到ssd目录下,准备与原模型合并。

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel
#加载原模型
base_model = '/ssd/dongzhenheng/LLM/chatglm3-6b'
base_model = AutoModel.from_pretrained(base_model, trust_remote_code=True).cuda()
#加载微调的模型
lora_model_path = '/ssd/dongzhenheng/Work/ChatGLM3-6B微调/epoch-2-step-84'
lora_model = PeftModel.from_pretrained(base_model,lora_model_path, torch_dtype=torch.float16)
lora_model.to("cpu")
#合并
merged_model = lora_model.merge_and_unload()
#合并的模型存储
new_model_directory = '/ssd/dongzhenheng/Work/ChatGLM3-6B微调'
merged_model.save_pretrained(new_model_directory, max_shard_size="2048MB", safe_serialization=True)

 2-6 验证

        使用合并的模型对测试数据验证。

#加载模型
new_model_directory = '/ssd/dongzhenheng/Work/ChatGLM3-6B微调'
tokenizer = AutoTokenizer.from_pretrained(new_model_directory, trust_remote_code=True)
model = AutoModel.from_pretrained(new_model_directory, trust_remote_code=True).cuda()
model.eval()
#输入
instruction = "你现在是一个信息抽取模型,请你帮我抽取出关系内容为\"性能故障\", \"部件故障\", \"组成\"和 \"检测工具\"的相关三元组,三元组内部用\"_\"连接,三元组之间用\\n分割。文本:"
input = "故障现象:发动机水温高,风扇始终是低速转动,高速档不工作,开空调尤其如此。 

#验证
response, _ = model.chat(tokenizer, instruction+input_data, history=None)
print(response)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1473099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络防御-VPN概述

目录 VPN的概述VPN的分类根据建设的单位不同分类根据组网方式不同分类根据VPN技术实现的层次来进行分类 VPN其他常用技术身份认证技术 --- 身份认证是VPN技术的前提。加解密技术 --- 以此来抵抗网络中的一些被动攻击数据认证技术 --- 验货 --- 保证数据的完整性密钥管理技术 VP…

CS_上线三层跨网段机器(完整过程还原)

以前讲过用cs_smb_beacon上线不出网机器,但是真实的网络拓扑肯定不止这么一层的网络! 所以我就来搭建一个复杂一点的网络环境!! 当然了,这三台电脑之间都是不同的网段,(但是同属于一个域环境&a…

C# 学习第二弹

一、变量 存储区(内存)中的一个存储单元 (一)变量的声明和初始化 1、声明变量——根据类型分配空间 ①声明变量的方式 —变量类型 变量名 数值; —变量类型 变量名; 变量名 数值; —变…

【Rust】简介、安装和编译

文章目录 一、Rust简介二、Rust 安装三、Rust 程序结构3.1 模块(Modules):3.2 函数(Functions):3.3 变量(Variables):3.4 控制流(Control Flow)&a…

Coursera吴恩达机器学习专项课程02:Advanced Learning Algorithms 笔记 Week03

Week 03 of Advanced Learning Algorithms 笔者在2022年7月份取得这门课的证书,现在(2024年2月25日)才想起来将笔记发布到博客上。 Website: https://www.coursera.org/learn/advanced-learning-algorithms?specializationmachine-learnin…

Centos配置SSH并禁止密码登录

CentOS8 配置SSH使用密钥登录并禁止密码登录 一、概念 SSH 为 Secure Shell 的缩写,SSH 为建立在应用层基础上的安全协议。SSH 是较可靠,专为远程登录会话和其他网络服务提供安全性的协议。 SSH提供两个级别的认证: 基于口令的认证 基于密钥的认证 基本使…

SkyWalking微服务链路追踪实战

目录 skywalking是什么? Skywalking主要功能特性 Skywalking整体架构 SkyWalking 环境搭建部署 SkyWalking快速开始 SkyWalking Agent追踪微服务 通过jar包方式接入 在IDEA中使用Skywalking Skywalking跨多个微服务追踪 Skywalking集成日志框架 Skywalki…

【c语言】if 选择语句

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:C语言 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步&…

计网Lesson16 - TCP选择重传和流量控制

文章目录 1. 选择性确认(SACK)2. TCP流量控制2.1 基本情况2.2 特殊情况 1. 选择性确认(SACK) TCP通信中,发送序列中的某一包丢失(1,2,3,4,5 中 3 丢失&#…

Promise 介绍与基本使用 - 学习笔记

Promise 介绍与基本使用 1、 Promise 是什么?2、创建 Promise 实例对象3、Promise 实例方法4、Promise 的基本工作流程5、实例方法6、静态方法7、async 和 await7.1、关键字7.2、实例7.3、区别7.4、为什么使用 async/await 比较好? 1、 Promise 是什么&a…

NUS神经网络生成我感觉解读过于夸大了

网上对其解读有点过了,只是合成了最后标准化层的参数,或者是更多的其他层参数。而不是网络结构。对于新任务下的网络结构以及参数如何生成,应该是做不到的,论文意义有限。 论文片段:我们提出了神经网络扩散&#xff0…

数据可视化引领智慧仓储新时代

随着科技的飞速发展,数据可视化已然成为智慧仓储领域的璀璨明珠,其强大的功能和多面的作用让智慧仓储焕发出勃勃生机。让我们一同探索,数据可视化究竟在智慧仓储中起到了怎样的作用。下面我就以可视化从业者的角度来简单谈谈这个话题。 在这…

【练习——打印每一位数】

打印一个数的每一位 举个例子:我们现在要求打印出123的每一位数字。我们需要去想123%10等于3,就可以把3单独打印出来了,然后再将123/10可以得到12,将12%10就可以打印出2,而我们最后想打印出1,只需要1%10就…

数据隐私安全趋势

在当今社交媒体和开源开发的世界中,共享似乎已成为社会常态。毕竟,我们都被教导分享就是关怀。这不仅适用于个人,也适用于公司:无论是有意在社交媒体帐户和公司网站上,还是无意中通过员工的行为,公司可能会…

树莓派使用git clone时报错failed: The TLS connection was non-properly terminated.

fatal: unable to access https://github.com/jacksonliam/mjpg-streamer.git/: gnutls_handshake() failed: The TLS connection was non-properly terminated. 原因:权限不足 解决办法:sudo git clone 加对应网址。 sudo git clone https://github.co…

golang gin单独部署vue3.0前后端分离应用

概述 因为公司最近的项目前端使用vue 3.0,后端api使用golang gin框架。测试通过后,博文记录,用于备忘。 步骤 npm run build,构建出前端项目的dist目录,dist目录的结构具体如下图 将dist目录复制到后端程序同级目录…

排序(9.17)

1.排序的概念及其运用 1.1排序的概念 排序 :所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性 :假定在待排序的记录序列中,存在多个具有相同的关键字的记…

网络安全与信创产业发展:构建数字时代的护城河

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua,在这里我会分享我的知识和经验。&#x…

CMU15445实验总结(Spring 2023)

CMU15445实验总结(Spring 2023) 背景 菜鸟博主是2024届毕业生,学历背景太差,导致23年秋招无果,准备奋战春招。此前有读过LevelDB源码的经历,对数据库的了解也仅限于LevelDB。奔着”有对比才能学的深“的理念,以及缓解…

Java之SpringMVC源码详解

SpringMVC源码 一、SpringMVC的基本结构 1.MVC简介 以前的纯Servlet的处理方式: Overrideprotected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {String type req.getParameter(Constant.REQUEST_PA…