大模型笔记之-XTuner微调个人小助手认知

news2025/1/13 17:30:51

前言

使用XTuner 微调个人小助手认知

一、下载模型

#安装魔搭依赖包
pip install modelscope
  1. 新建download.py内容如下
    其中Shanghai_AI_Laboratory/internlm2-chat-1_8b是魔搭对应的模型ID
    cache_dir='/home/aistudio/data/model’为指定下载到本地的目录
from modelscope import snapshot_download
model_dir = snapshot_download('Shanghai_AI_Laboratory/internlm2-chat-1_8b',cache_dir='/home/aistudio/data/model')

二、安装 XTuner

1.创建环境

#新建一个code文件夹
mkdir -p /home/aistudio/data/code
#切换到该目录下
cd /home/aistudio/data/code
#从 Github 上下载源码
git clone -b v0.1.21  https://github.com/InternLM/XTuner
#进入源码目录
cd XTuner
# 执行安装
pip install -e '.[deepspeed]'

2.结果验证

xtuner version

三. 快速开始

这里我们用 internlm2-chat-1_8b 模型,通过 QLoRA 的方式来微调一个自己的小助手认知作为案例来进行演示

1.准备数据

#新建datas文件夹
mkdir -p datas
#创建json文件
touch datas/assistant.json

2.数据生成

1.新建一个xtuner_generate_assistant.py内容如下
2.修改neme由“伍鲜同志”改为“阿豪”
3.修改数据写入路径为刚刚创建的json文件

import json

# 设置用户的名字
name = '阿豪'
# 设置需要重复添加的数据次数
n = 8000

# 初始化数据
data = [
    {"conversation": [{"input": "请介绍一下你自己", "output": "我是{}的小助手,内在是上海AI实验室书生·浦语的1.8B大模型哦".format(name)}]},
    {"conversation": [{"input": "你在实战营做什么", "output": "我在这里帮助{}完成XTuner微调个人小助手的任务".format(name)}]}
]

# 通过循环,将初始化的对话数据重复添加到data列表中
for i in range(n):
    data.append(data[0])
    data.append(data[1])

# 将data列表中的数据写入到'datas/assistant.json'文件中
with open('datas/assistant.json', 'w', encoding='utf-8') as f:
    # 使用json.dump方法将数据以JSON格式写入文件
    # ensure_ascii=False 确保中文字符正常显示
    # indent=4 使得文件内容格式化,便于阅读
    json.dump(data, f, ensure_ascii=False, indent=4)

3.初始化数据

#执行
python xtuner_generate_assistant.py 

在这里插入图片描述

4.获取训练脚本

xtuner copy-cfg internlm2_chat_1_8b_qlora_alpaca_e3 .

修改内容如下

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/mnt/workspace/model/Shanghai_AI_Laboratory/internlm2-chat-1_8b'
use_varlen_attn = False

# Data
alpaca_en_path = '/mnt/workspace/code/datas/assistant.json'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请介绍一下你自己', 'Please introduce yourself'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

5.开启训练

xtuner train ./internlm2_chat_1_8b_qlora_alpaca_e3_copy.py

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

微调前

在这里插入图片描述
在这里插入图片描述

6. 模型格式转换

pth_file=`ls -t ./work_dirs/internlm2_chat_1_8b_qlora_alpaca_e3_copy/*.pth | head -n 1`
export MKL_SERVICE_FORCE_INTEL=1
export MKL_THREADING_LAYER=GNU
xtuner convert pth_to_hf ./internlm2_chat_1_8b_qlora_alpaca_e3_copy.py ${pth_file} ./hf

在这里插入图片描述

7.模型合并

export MKL_SERVICE_FORCE_INTEL=1
export MKL_THREADING_LAYER=GNU
xtuner convert merge /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b ./hf ./merged --max-shard-size 2GB

在这里插入图片描述

8.测试效果

python -m streamlit run xtuner_streamlit_demo.py 

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2064432.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Stable Diffusion的微调方法原理总结

目录 1、Textural Inversion(简易) 2、DreamBooth(完整) 3、LoRA(灵巧) 4、ControlNet(彻底) 5、其他 1、Textural Inversion(简易) 不改变网络结构&…

Ciallo~(∠・ω・ )⌒☆第二十五篇 Redis

Redis 是一个高性能的键值存储数据库,它能够在内存中快速读写数据,并且支持持久化到磁盘。它被广泛应用于缓存、队列、实时分析等场景。 一、启动redis服务器 要打开redis服务器,需要在终端中输入redis-server命令。确保已经安装了redis&…

【Java】/* 链式队列 和 循环队列 - 底层实现 */

一、链式队列 1. 使用双向链表实现队列,可以采用尾入,头出 也可以采用 头入、尾出 (LinkedList采用尾入、头出) 2. 下面代码实现的是尾入、头出: package bageight;/*** Created with IntelliJ IDEA.* Description:* User: tangyuxiu* Date: …

mOTA v2.0

mOTA v2.0 一、简介 本开源工程是一款专为 32 位 MCU 开发的 OTA 组件,组件包含了 bootloader 、固件打包器 (Firmware_Packager) 、固件发送器 三部分,并提供了基于多款 MCU (STM32F1 / STM32F407 / STM32F411 / STM32L4) 和 YModem-1K 协议的案例。基…

【文献及模型、制图分享】2000—2020年中国青饲料播种面积及供需驱动因素的时空格局

文献介绍 高产、优质的青饲料对于国家畜牧业发展和食物供给至关重要。然而,当前对于青饲料播种面积时空变化格局及其阶段性特征、区域差异以及影响因素等尚未清楚。 本文基于省级面板数据分析了2000—2020年青饲料种植的时空格局变化,结合MODIS-NPP产品…

Nginx 405 not allowed

问题原因:nginx不允许静态文件被post请求 解决:添加error_page 405 200 $request_uri;

白酒与家庭:团圆时刻的需备佳品

在中国传统文化中,家庭是社会的基石,是每个人心灵的港湾。而团圆,则是家庭生活中较美好的时刻。在这样一个特殊的日子里,白酒,尤其是豪迈白酒(HOMANLISM),成为了团圆时刻的需备佳品。…

了解JS数组元素及属性

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1、定义数组并输出2、查询数组的长度3、访问数组的第一个元素4、访问数组中第一个元素的xxx属性5、从数组元素中提取ID并存储到搜索参数对象 提示:以下是…

C++设计模式1:单例模式(懒汉模式和饿汉模式,以及多线程问题处理)

饿汉单例模式 程序还没有主动获取实例对象&#xff0c;该对象就产生了&#xff0c;也就是程序刚开始运行&#xff0c;这个对象就已经初始化了。 class Singleton { public:~Singleton(){std::cout << "~Singleton()" << std::endl;}static Singleton* …

KUKA KR C2 中文操作指南 详情见目录

KUKA KR C2 中文操作指南 详情见目录

Selenium + Python 自动化测试22(PO+数据驱动)

我们的目标是&#xff1a;按照这一套资料学习下来&#xff0c;大家可以独立完成自动化测试的任务。 上一篇我们讨论了PO模式和unittest框架结合起来使用。 本篇文章我们综合一下之前学习的内容&#xff0c;如先将PO模式、数据驱动思想和我们生成HTML报告融合起来&#xff0c;综…

​2024年AI新蓝海:三门生意如何借AI之力,开启变现新篇章

【导语】在这个日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;已不再是遥不可及的未来科技&#xff0c;而是正逐步渗透到我们生活的方方面面&#xff0c;成为推动产业升级的重要力量。你是否还在为传统行业的未来而忧虑&#xff1f;别担心&#xff0c;AI正以其…

Pandas DataFrame 数据转换处理和多条件查询

工作中需要处理一个比较大的数据&#xff0c;且当中需要分析的日期类型字段为字符串型&#xff0c;需要进行转换&#xff0c;获得一个新的字段用于时间统计。我们应用 datetime.datetime.strptime 函数进行转换。 数据读取与时间列补充代码如下&#xff1a; import pandas as…

原来ChatGPT是这么评价《黑神话:悟空》的啊?

《黑神话&#xff1a;悟空》一经上线便迅速吸引了全球的目光&#xff0c;成为了今日微博热搜榜上的焦点话题。作为中国首款现象级的中国3A大作&#xff0c;它的发布无疑引发了广泛的关注与讨论。 《黑神话&#xff1a;悟空》&#xff0c;这款3A国产游戏大作&#xff0c;由国内游…

根据状态的不同,显示不同的背景颜色

文章目录 前言HTML模板部分JavaScript部分注意&#xff1a;主要差异影响如何处理示例 总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 实现效果&#xff1a; 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 根据给定的状态…

文件操作2(函数的专栏)

1、文件的打开和关闭 1.1文件指针 在缓冲文件系统中&#xff0c;关键的概念是“文件类型指针”&#xff0c;简称“文件指针”取名为FILE。 例如&#xff0c; VS2013编译环境提供的 stdio. h头文件中有以下的文件类型申明&#xff1a; struct _ iobuf { char *_ ptr; int _…

【YOLO5 项目实战】(6)YOLO5+StrongSORT 目标追踪

欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 【YOLO5 项目实战】&#xff08;1&#xff09;YOLO5 环境配置与检测 【YOLO5 项目实战】&#xff08;2&#xff09;使用自己的数据集训练目标检测模型 【YOLO5 项目实战】&#xff08;6&#xff09;Y…

数据库机器上停service360safe

发现有个数据库的负载较高&#xff0c;发现有360safe&#xff0c;就准备停了该服务再观察 [rootdb1 ~]# ps -ef |grep 360 root 970 1 0 15:12 ? 00:00:10 /opt/360safe/360entclient root 976 970 5 15:12 ? 00:18:42 /opt/360…

Linux之RabbitMQ集群部署

RabbitMQ 消息中间件 1、消息中间件 消息(message)&#xff1a; 指在服务之间传送的数据。可以是简单的文本消息&#xff0c;也可以是包含复杂的嵌入对象的消息 消息队列(message queue): 指用来存放消息的队列&#xff0c;一般采用先进先出的队列方式&#xff0c;即最先进入的…

关于springboot的异常处理以及源码分析(一)

一、什么是异常处理 1、文档定义 首先我们先来看springboot官方对于异常处理的定义。springboot异常处理 在文档的描述中&#xff0c;我们首先可以看到的一个介绍如下&#xff1a; By default, Spring Boot provides an /error mapping that handles all errors in a sensib…