15chatGLM3半精度微调

news2024/11/14 19:11:08

1 模型准备       

         数据依然使用之前的数据,但是模型部分我们使用chatglb-3,该模型大小6B,如果微调的话需要24*4 = 96GB,硬件要求很高,那么我们使用半精度微调策略进行调试,半精度微调有很多坑啊,注意别踩到了;

#依赖 pip install modelscope

# pip install transformers==4.40.2, 不知道为什么使用之前的版本推理有问题!

模型

http://chatGLM3

模型文件很大,综合十几个G的,自己试试吧;

2 模型介绍

如果假设 ChatGLM3 是 ChatGLM 系列的后续版本,那么可以推测它可能是对现有 ChatGLM 模型的进一步改进和扩展。这样的改进可能包括但不限于以下几个方面:

  1. 模型规模:增加模型的参数量,以提高模型的表达能力和泛化能力。
  2. 架构改进:引入新的架构设计,例如更先进的注意力机制或其他创新技术,以提高模型的性能。
  3. 训练数据:使用更多的训练数据,特别是高质量的对话数据,以增强模型的理解和生成能力。
  4. 优化技术:采用更高效的训练方法和优化算法,以加速训练过程并提高模型的收敛速度。
  5. 多模态能力:增强模型处理多种模态数据(如图像、视频等)的能力,使其成为一个更全面的多模态模型。
  6. 安全性与伦理:加强对模型输出的安全性和伦理性的控制,确保生成的内容更加可靠和安全。

ChatGLM2与ChatGLM3模型架构是完全一致的,ChatGLM与后继者结构不同。可见ChatGLM3相对于ChatGLM2没有模型架构上的改进。

相对于ChatGLM,ChatGLM2、ChatGLM3模型上的变化:

  1. 词表的大小从ChatGLM的150528缩小为65024 (一个直观的体验是ChatGLM2、3加载比ChatGLM快不少)
  2. 位置编码从每个GLMBlock一份提升为全局一份
  3. SelfAttention之后的前馈网络有不同。ChatGLM用GELU(Gaussian Error Linear Unit)做激活;ChatGLM用Swish-1做激活。而且ChatGLM2、3应该是修正了之前的一个bug,因为GLU(Gated Linear Unit)本质上一半的入参是用来做门控制的,不需要输出到下层,所以ChatGLM2、3看起来前后维度不一致(27392->13696)反而是正确的。

model 

使用Lora进行微调:

chatGLM进行切词会生成:

from transformers import AutoTokenizer, AutoModel
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

ds = Dataset.load_from_disk("../data/")
# trust_remote_code=True 注意添加
tokenizer = AutoTokenizer.from_pretrained("../model/chatglm3-6b/", trust_remote_code=True)

def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = "\n".join([example["instruction"], example["input"]]).strip()     # query
    instruction = tokenizer.build_chat_input(instruction, history=[], role="user")  # [gMASK]sop<|user|> \n query<|assistant|>
    response = tokenizer("\n" + example["output"], add_special_tokens=False)        # \n response, 缺少eos token
    input_ids = instruction["input_ids"][0].numpy().tolist() + response["input_ids"] + [tokenizer.eos_token_id]
    attention_mask = instruction["attention_mask"][0].numpy().tolist() + response["attention_mask"] + [1]
    labels = [-100] * len(instruction["input_ids"][0].numpy().tolist()) + response["input_ids"] + [tokenizer.eos_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds


import torch


# 多卡情况,可以去掉device_map="auto",否则会将模型拆开
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="../model/chatglm3-6b/",
                                             trust_remote_code=True, 
                                             torch_dtype=torch.bfloat16)

from peft import LoraConfig, TaskType, get_peft_model, PeftModel

config = LoraConfig(target_modules=["query_key_value"], modules_to_save=["post_attention_layernorm"])
config

model = get_peft_model(model, config)

model.print_trainable_parameters()

from transformers.trainer_callback import TrainerCallback
import matplotlib.pyplot as plt

class PrintLossCallback(TrainerCallback):
    
    def __init__(self):
        self.losses = []
        self.steps = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # 打印训练过程中的日志信息
        try:
            if logs is not None:
                print(f"Step {state.global_step}: Loss={logs['loss']:.4f}, Learning Rate={logs['learning_rate']:.6f}")
                self.losses.append(logs['loss'])
                self.steps.append(state.global_step)

        except Exception as e :
            print(f'on_log error {e}')
    
    def plot_losses(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.steps, self.losses, label='Training Loss')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.show()


args = TrainingArguments(
    output_dir="./chatbot_gml3",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    logging_steps=10,
    num_train_epochs=1,
    learning_rate=1e-4,
    remove_unused_columns=False,
    save_strategy="epoch"
)

plot_losses_callback = PrintLossCallback()

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,#.select(range(6000)),
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[plot_losses_callback]  # 注册自定义回调
)
if torch.cuda.is_available():
    trainer.model = trainer.model.to("cuda")
# 训练模型
trainer.train()

可以看到loss终于到达了1.9; 

效果还可以,可以作为一个闲聊的机器人!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

只会SQL语句,可以做什么工作?

1、SQL是什么 首先简单介绍一下SQL&#xff08;Structured Query Language&#xff09;&#xff0c;是一种可以进行数据提取、聚合、分析&#xff0c;并对数据库进行构建和修改的编程语言。 相对来说&#xff0c;SQL上手非常容易&#xff0c;因为语法结构比较固定&#xff0c…

第一性原理计算从定义到场景到硬件配置详细讲解

第一性原理计算&#xff0c;又称为从头计算&#xff08;The Ab initio Calculation&#xff09;&#xff0c;是一种基于量子力学原理&#xff0c;通过计算机模拟来预测材料、分子、固体等体系性质的方法。这种方法的核心思想是不依赖于实验数据或经验参数&#xff0c;而是直接从…

如何纯手动的创建SpringBoot工程?

1、打开项目结构 2、new 一个新模块 3、所需全部选配好 4、 创建好之后&#xff0c;目录如下 5、在pom文件中&#xff0c;做第一件事情&#xff08;让当前的工程继承一个父工程&#xff09; &#xff08;这是一个固定的写法&#xff1a;spring-boot-starter-parent&#xff09;…

JavaWeb - Maven

Maven apache旗下的一个来源项目&#xff0c;一款用于管理和构建java项目的工具&#xff0c;它基于项目对象模型&#xff08;POM&#xff09;的概念&#xff0c;通过一小段描述信息来管理项目的构建。 作用 安装 解压官网下载的压缩包 配置本地仓库&#xff0c;修改conf/se…

接口请求400

接口请求400 在Web开发中&#xff0c;接口请求错误是开发者经常遇到的问题之一。其中&#xff0c;400错误&#xff08;Bad Request&#xff09;尤为常见&#xff0c;它表明发送到服务器的请求有误或不能被服务器理解。本文将深入探讨接口请求400错误&#xff0c;从常见报错问题…

springcloud微服务入门

1.架构的演变 目前我们接触的比较多的是单体架构&#xff0c;指的是将所有功能集中在一个项目中开发&#xff0c;打成一个包部署。 这样的架构优点在于&#xff0c;架构简单&#xff0c;把各个功能集中在一起方便操作管理&#xff0c;部署成本也比较低但是缺点也是很明显&#…

让AI给你写代码(10.1): 按接口编程的思想,统一利用内部和外部的接口,逐步扩展和提升AI编程能力

先总结一下AI编程小助手已具备的能力&#xff0c;目前AI小助手已经可以利用本地知识库和在线大模型&#xff08;我们用的是qwen&#xff09;生成可测试&#xff0c;可执行代码的能力&#xff08;具体流程参考从让AI给你写代码&#xff08;9.1&#xff09;&#xff09;&#xff…

※※Leetcode Hot 100刷题记录 -Day8(和为k的子数组)

问题描述&#xff1a; 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。子数组是数组中元素的连续非空序列。 示例 1&#xff1a; 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2示例 2&#xff1a; 输入&#xff1a…

java开发面试:AOT有什么优缺点/适用于什么场景/AOT和JIT的对比、逃逸分析和对象存储在堆上的关系、高并发中的集合有哪些问题

JDK9引入了AOT编译模式。 AOT 有什么优点&#xff1f;适用于什么场景&#xff1f; JDK 9 引入了一种新的编译模式 AOT(Ahead of Time Compilation) 。 和 JIT 不同的是&#xff0c;这种编译模式会在程序被执行前就将其编译成机器码&#xff0c;属于静态编译&#xff08;C、 C…

【Redis详解】Redis安装+主从复制+哨兵模式+Redis Cluster

目录 一、Redis简介 1.1 关系型数据库和NoSQL数据库 二、Redis安装 2.1 rpm 安装 2.2 源码安装 三、Redis基本操作 四、Redis主从复制 4.1 配置主从同步 4.2 主从同步过程 五、Redis高可用--哨兵模式 5.1 哨兵的实验过程 六、数据保留 七、Redis Cluster 7.1 部署…

【办公软件】Excel如何开n次方根

在文章&#xff1a;【分立元件】电阻的基础知识中我们学习电阻值、电阻值容差标注相关标准。知道了标准将电阻值标准数列化。因此电阻值并非1Ω、2Ω、3Ω那样的整数&#xff0c;而是2.2Ω、4.7Ω那样的小数。 这是因为电阻值以标准数(E系列)为准。系列的“E”是Exponent(指数)…

鸿蒙开发占多列的瀑布流

鸿蒙开发占多列的瀑布流 正常样式的瀑布流没什么好说&#xff0c;大家看下官方文档应该都写得来。关键是有些item要占多列&#xff0c;整行的效果 先看下效果图&#xff1a; 还有底部的效果图的&#xff0c;就不放了&#xff0c;你们应该也看得懂的 思路&#xff1a; 关键在…

libtorch---day04[MNIST数据集]

参考pytorch。 数据集读取 MNIST数据集是一个广泛使用的手写数字识别数据集&#xff0c;包含 60,000张训练图像和10,000张测试图像。每张图像是一个 28 28 28\times 28 2828像素的灰度图像&#xff0c;标签是一个 0到9之间的数字&#xff0c;表示图像中的手写数字。 MNIST …

使用Aqua进行WebUI测试(Pytest)——介绍篇(附汉化教程)

一、在创建时选择Selenium with Pytest 如果选择的是Selenium&#xff0c;则只能选择Java类语言 选择selenium with Pytest&#xff0c;则可以选择Python类语言 Environment 其中的【Environment】可选New 和 Existing New &#xff1a;选择这个选项意味着你希望工具为你创…

常用企业技术架构开发速查工具列表

对于Java开发者来说,不光要关注业务代码也要注重架构的修炼。日常用到的工具组件都是我们架构中重要的元素,服务于应用系统。我们应该选择适合应用体量的架构避免过度设计,最简单的方式就是矩阵方式去分析每个组件的适用场景优缺点,从而综合评估做好决策。 程序员大多数时间…

一次性说清楚,微软Mos国际认证

简介&#xff1a; Microsoft Office Specialist&#xff08;MOS&#xff09;中文称之为“微软办公软件国际认证”&#xff0c;是微软为全球所认可的Office软件国际性专业认证&#xff0c;全球有168个国家地区认可&#xff0c;每年有近百万人次参加考试&#xff0c;它能有效证明…

Elasticsearch集群架构

Elasticsearch是一种分布式搜索引擎&#xff0c;基于Apache Lucene构建&#xff0c;支持全文搜索、结构化搜索、分析和实时数据处理。 节点&#xff08;Node&#xff09; 节点是集群中的一台服务器。根据节点的角色&#xff0c;可以分为以下几种类型&#xff1a; 主节点&#…

uniapp中slot插槽用法

1.slot的用法 1.1 简单概念 元素作为组件模板之中的内容分发插槽&#xff0c;<slot> 元素自身将被替换 是不是这段话听着有点迷? 那么直接开始上代码 此时创建一个简单的页面&#xff0c;在中间写上一个<slot></slot>标签&#xff0c;标签内并没有数据 …

MySQL——隔离级别及解决方案

CRUD不加控制&#xff0c;会有什么问题&#xff1f; 比如上图场景&#xff0c;当我们的客户端A发现还有一张票的时候&#xff0c;将票卖掉&#xff0c;嗨还没有执行更新数据库的时候&#xff0c;客户端B又检查票数&#xff0c;发现票数大于0&#xff0c;又卖掉了一张票。然后客…

基于FPGA实现SD NAND FLASH的SPI协议读写

基于FPGA实现SD NAND FLASH的SPI协议读写 在此介绍的是使用FPGA实现SD NAND FLASH的读写操作&#xff0c;以雷龙发展提供的CS创世SD NAND FLASH样品为例&#xff0c;分别讲解电路连接、读写时序与仿真和实验结果。 目录 1 FLASH背景介绍 2 样品申请 3 电路结构与接口协议 …