深度学习:从零开始的DeepSeek-R1-Distill有监督微调训练实战(SFT)

news2025/4/21 13:55:42

原文链接:从零开始的DeepSeek微调训练实战(SFT)

微调参考示例:由unsloth官方提供https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynbhttps://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb

本文使用modelscope社区提供的免费GPU示例进行复现。

魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://www.modelscope.cn/my/overview

基础概念

预训练模型 (Pre-trained Model): 预训练模型是指在大规模数据集上(如Wikipedia、书籍、网页等)进行过训练的模型。这些模型学习到了通用的语言知识和模式。你可以把它们想象成已经掌握了基本语法和常识的“学生”。常见的预训练模型有BERT、GPT、Llama、DeepSeek等。

微调 (Fine-tuning): 微调是指在预训练模型的基础上,使用特定任务的数据集继续训练模型,使其适应特定任务或领域。就像让一个已经掌握基本知识的学生,学习特定专业的知识。

为什么需要微调?: 预训练模型虽然强大,但它们是通用的。对于特定任务(如医疗问答、代码生成、情感分析等),预训练模型可能表现不佳。微调可以让模型更好地适应特定任务,提高性能。

SFT (Supervised Fine-Tuning): SFT是一种微调方法,它使用带有标签的数据集进行训练。例如,在医疗问答任务中,数据集会包含问题和对应的正确答案。模型通过学习这些问题和答案之间的关系,来提高在特定任务上的表现。

SFT vs. RLHF:
- SFT (Supervised Fine-tuning): 使用标注好的数据集进行训练。模型学习输入和输出之间的直接映射。简单高效,但依赖于高质量的标注数据。
- RLHF (Reinforcement Learning from Human Feedback): 通过人类反馈来训练模型。首先使用SFT,然后通过人类对模型输出进行打分,并使用强化学习算法来优化模型。可以更好地捕捉人类偏好,但更复杂,成本更高。
- 总结: SFT是基础,RLHF是进阶。通常先进行SFT,再根据需要进行RLHF。

高效微调 (Efficient Fine-tuning): 高效微调是指在有限的计算资源下,对大型模型进行微调的方法。例如,LoRA(Low-Rank Adaptation)只微调模型中的部分参数,从而减少计算量和内存需求。

环境准备

unsloth

  • Unsloth 是什么?

    Unsloth 是一个专为大型语言模型(LLM)微调和推理设计的框架。它的主要目标是提高训练速度和降低内存消耗,让用户能够在有限的硬件资源上更高效地进行 LLM 的操作。

  • Unsloth 的主要特点

    • 速度快:Unsloth 通过各种优化技术(如 Flash Attention、量化等)显著提高了 LLM 的训练和推理速度。在某些情况下,速度提升可达数倍。

    • 内存占用低:Unsloth 通过优化内存使用,使得在较小显存的 GPU 上也能微调大型模型。

    • 易于使用:Unsloth 提供了简洁的 API,方便用户快速上手。

    • 支持多种模型:Unsloth 支持多种流行的 LLM,如 Llama、Mistral、Phi、Qwen 等。

  • Unsloth 的安装

    • 直接使用pip命令安装即可

    • pip install unsloth
      pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

WandB (Weights & Biases) 安装 

  • WandB 是什么?

    WandB 是一个用于机器学习实验跟踪、可视化和协作的平台。它可以帮助你记录实验的各种指标、超参数、模型权重、数据集等,并提供交互式的可视化界面,方便你分析实验结果和比较不同实验的表现。

  • WandB 的主要特点

    • 实验跟踪:记录实验的各种指标(如损失、准确率、学习率等)、超参数、代码版本、数据集等。

    • 可视化:提供交互式的图表,方便你分析实验结果。

    • 协作:支持多人协作,方便团队成员共享实验结果和讨论。

    • 超参数优化:支持自动超参数搜索,帮助你找到最佳的超参数组合。

    • 模型管理:可以保存和版本控制模型权重。

    • 报告生成:可以自动生成实验报告。

  • WandB 的安装和注册

    • 安装:使用 pip 安装 WandB:

    • pip install wandb
    • 注册:
    • 访问 WandB 官网(https://wandb.ai/site)并注册账号。

    • 注册后,在你的个人设置页面找到 API Key,复制它。

    • WandB 在环境准备中的作用

      在 SFT 环境准备中,WandB 主要用于:

      • 监控训练过程:在训练过程中,WandB 会自动记录各种指标,如损失、学习率等,并提供实时更新的图表。

      • 记录超参数:WandB 会记录你使用的超参数,方便你后续复现实验和比较不同超参数的效果。

      • 保存模型:你可以使用 WandB 保存训练过程中的模型权重,方便后续加载和使用。

      • 分析实验结果:WandB 提供了丰富的可视化工具,可以帮助你分析实验结果,找出最佳的模型和超参数。

  • 模型下载

    ModelScope模型地址:https://www.modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B

       创建DeepSeek-R1-Distill-Qwen-7B文件夹,用于保存下载的模型权重:

  • mkdir ./DeepSeek-R1-Distill-Qwen-7B
    

       创建成功后,可使用如下命令下载模型:

  • modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --local_dir ./DeepSeek-R1-Distill-Qwen-7B
  • 模型权重文件  
  • config.json
  • 内容: 这个文件包含了模型的配置信息,它是一个 JSON 格式的文本文件。这些配置信息定义了模型的架构、层数、隐藏层大小、注意力头数等。
  • 重要性: 这是模型的核心配置文件,加载模型时会读取这个文件来构建模型的结构。
{
  "architectures": [
    "Qwen2ForCausalLM"  // 指定模型的架构类型为 Qwen2ForCausalLM,这是一个用于因果语言建模(生成文本)的 Qwen2 模型。
  ],
  "attention_dropout": 0.0,  // 在注意力机制中使用的 dropout 比率。设置为 0.0 表示不使用 dropout。Dropout 是一种正则化技术,用于防止过拟合。
  "bos_token_id": 151646,  // 句子开头标记(Beginning of Sentence)的 ID。在分词器中,每个词或标记都有一个唯一的 ID。
  "eos_token_id": 151643,  // 句子结束标记(End of Sentence)的 ID。
  "hidden_act": "silu",  // 隐藏层的激活函数。SiLU(Sigmoid Linear Unit)是一种激活函数。
  "hidden_size": 3584,  // 隐藏层的大小(维度)。
  "initializer_range": 0.02,  // 用于初始化模型权重的标准差。
  "intermediate_size": 18944,  // 前馈网络(Feed-Forward Network)中间层的大小。
  "max_position_embeddings": 131072,  // 模型可以处理的最大序列长度(位置嵌入的数量)。
  "model_type": "qwen2",  // 模型类型为 qwen2。
  "num_attention_heads": 28,  // 注意力机制中注意力头的数量。
  "num_hidden_layers": 28,  // 模型中隐藏层(Transformer 层)的数量。
  "num_key_value_heads": 4,  //  键值头的数量。用于分组查询注意力(Grouped-Query Attention, GQA)。如果该值小于`num_attention_heads`,则表示启用了GQA, 否则为多头注意力(Multi-Head Attention, MHA)。
  "rms_norm_eps": 1e-06,  // RMSNorm(Root Mean Square Layer Normalization)中使用的 epsilon 值,用于防止除以零。
  "rope_theta": 10000.0,  // RoPE(Rotary Positional Embeddings)中使用的 theta 值。RoPE 是一种位置编码方法。
  "tie_word_embeddings": false,  // 是否将词嵌入矩阵和输出层的权重矩阵绑定(共享)。设置为 `false` 表示不绑定。
  "torch_dtype": "bfloat16",  // 模型使用的默认数据类型。`bfloat16` 是一种 16 位浮点数格式,可以提高计算效率并减少内存占用。
  "transformers_version": "4.48.3",  // 使用的 Transformers 库的版本。
  "use_cache": true,  // 是否使用缓存机制来加速推理。设置为 `true` 表示使用缓存。
  "vocab_size": 152064  // 词汇表的大小(不同词或标记的数量)。
}
  • configuration.json 
  • 内容: 这个文件和 config.json 类似,通常包含模型的配置信息。在某些模型中,这两个文件可能是同一个文件,或者 configuration.json 包含了更详细的配置。对于 DeepSeek-R1-Distill-7B 模型,你可以认为它和 config.json 作用相同。
  • generation_config.json
  • 内容: 这个文件包含模型生成文本时的配置参数,例如解码方法(beam search、top-k sampling 等)、最大生成长度、温度系数等。
{
  "_from_model_config": true,  // 表示这些配置中的大部分是从模型的配置文件(config.json)中继承的。
  "bos_token_id": 151646,    // 句子开头标记(Beginning of Sentence)的 ID。
  "eos_token_id": 151643,    // 句子结束标记(End of Sentence)的 ID。
  "do_sample": true,       // 是否使用采样(sampling)方法生成文本。如果设置为 `false`,则使用贪婪解码(greedy decoding)。
  "temperature": 0.6,      // 温度系数。温度系数用于控制生成文本的随机性。值越高,生成的文本越随机;值越低,生成的文本越确定。
  "top_p": 0.95,          // Top-p 采样(nucleus sampling)的阈值。Top-p 采样只从概率最高的、累积概率超过 `top_p` 的词中进行采样。
  "transformers_version": "4.39.3"  // 使用的 Transformers 库的版本。(原文档中是4.39.3,这与之前config.json里的版本号不同,但通常情况下,版本号应当以config.json里的为准)
}
  • LICENSE

    内容: 这是一个文本文件,包含了模型的许可证信息。许可证规定了你可以如何使用、修改和分发模型。
  • model-00001-of-00002.safetensors 和 model-00002-of-00002.safetensors

    内容: 这些文件是模型权重文件,它们以 Safetensors 格式存储。Safetensors 是一种安全且高效的张量存储格式。由于模型很大,权重被分成了多个文件。
  • model.safetensors.index.json

    内容: 这是一个索引文件,用于指示哪些权重存储在哪个 .safetensors 文件中。当模型权重被分成多个文件时,需要这个索引文件来正确加载权重。
  • README.md

    内容: 这是一个 Markdown 格式的文本文件,通常包含模型的介绍、使用说明、示例代码等。
  • tokenizer_config.json

    内容: 包含分词器(Tokenizer)的配置信息。
{
  "add_bos_token": true,  // 是否在输入序列的开头添加句子开头标记(BOS token)。设置为 `true` 表示添加。
  "add_eos_token": false, // 是否在输入序列的结尾添加句子结束标记(EOS token)。设置为 `false` 表示不添加。
  "__type": "AddedToken",           //这是一个内部使用的类型标记,表示这是一个"添加的token"
  "content": "<|begin of sentence|>",  // 这是一个特殊标记的内容,表示句子的开始
  "lstrip": false,                    //在处理这个token时,是否移除左侧的空白符
  "normalized": true,                 // 是否对这个token进行标准化处理
  "rstrip": false,                     //是否移除右边的空白符
  "single_word": false,                // 是否将此token视为单个词
  "clean_up_tokenization_spaces": false, // 是否清理分词过程中的空格。设置为 `false` 表示不清理。
    "__type": "AddedToken",
  "content": "<|end of sentence|>",
  "lstrip": false,
  "normalized": true,
  "rstrip": false,
  "single_word": false,
  "legacy": true,                    // 是否使用旧版(legacy)的分词器行为。这里设置为`true`可能表示兼容旧版本。
  "model_max_length": 16384,      // 模型可以处理的最大序列长度。
    "__type": "AddedToken",
  "content": "<|end of sentence|>",
  "lstrip": false,
  "normalized": true,
  "rstrip": false,
  "single_word": false,
  "sp_model_kwargs": {},          // SentencePiece 模型的相关参数(这里为空)。
  "unk_token": null,            // 未知词标记(unknown token)。设置为 `null` 表示没有专门的未知词标记。
  "tokenizer_class": "LlamaTokenizerFast",  // 分词器的类名。`LlamaTokenizerFast` 表示这是一个快速版本的 Llama 分词器。
  "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool calls begin|><|tool call begin|>' + tool['type'] + '<|tool sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool call end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool call begin|>' + tool['type'] + '<|tool sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool call end|>'}}{{'<|tool calls end|><|end of sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool outputs end|>' + message['content'] + '<|end of sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end of sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool outputs begin|><|tool output begin|>' + message['content'] + '<|tool output end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool output begin|>' + message['content'] + '<|tool output end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool outputs end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %}"
  // ↑这是一个 Jinja2 模板,定义了对话的格式。它根据消息的角色(用户、助手、工具)和内容,构建最终的输入文本。

}

数据集准备

推理模型与通用模型相比,输出的回答包括了一段思考过程(Chain of Thoughts 思维链)。这个思考过程本质也是通过预测下一个token进行实现的,只不过DeepSeek系列模型输出时,会将思考过程放在一对特殊token <think>...</think>之间,</think>标签后的内容作为回答的正文。

微调推理模型时,同样需要包含思维链和最终回答两部份。因此,在围绕DeepSeek R1 Distill模型组进行微调的时候,微调数据集的回复部分文本也需要是包含推理 和最终回复两部分内容,才能使得DeepSeek R1模型组在保持既定回复风格的同时,强化模型能力,反之则会导致指令消融问题(模型回复不再包含think部分)。

modelscope社区提供了多样的推理数据集供开发者使用 。

原文采取由深圳大数据研究院发布的HuatuoGPT-o1模型的微调数据集—medical-o1-reasoning-SFT,地址:https://www.modelscope.cn/datasets/AI-ModelScope/medical-o1-reasoning-SFT。

本数据集将数据分为:

- Question:医疗问题

- Complex_CoT:进行诊疗的思维链

- Response:最终的答复

数据集中所有内容均为英文

模型演示

在进行微调前,我们可以了解一下模型的基本用法。

1. 加载已经下载到本地的模型

max_seq_length = 2048 # 指定输出的最大长度
dtype = None # 不指定模型精度,由unsloth框架自动检测
load_in_4bit = False # 采用int4量化,减少显存占用,但是会降低模型性能

# 加载模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./DeepSeek-R1-Distill-Qwen-7B", # 待微调的模型名称
    max_seq_length=max_seq_length, # 模型可以处理的最长序列长度
    dtype=dtype, # 限定模型浮点精度
    load_in_4bit=False # 是否使用int量化
)

2. 通过unsloth框架配置待微调的LoRA模型

'''
LoRA 的核心思想是,对于预训练模型的权重矩阵 W,不直接对其进行更新,
而是添加一个低秩分解矩阵 ΔW = A * B,
其中 A 和 B 是两个较小的矩阵。在微调过程中,只更新 A 和 B 的参数,而 W 的参数保持不变。
这样可以大大减少需要微调的参数数量,降低计算成本。
'''
model = FastLanguageModel.get_peft_model(
    model,
    r=8, # lora微调的秩  # 较小的 `r` 值会减少需要微调的参数数量,降低计算成本,但也可能降低模型的表达能力。# 较大的 `r` 值会增加参数数量,提高模型的表达能力,但也会增加计算成本。
         # 通常需要根据实际情况进行实验,选择合适的 `r` 值。一般来说,8、16、32、64 是常用的值。
    target_modules = ["q_proj", "k_proj", "v_proj", # 指定要应用 LoRA 的模块。这些模块通常是 Transformer 模型中的线性层。
                     "o_proj", "gate_proj", "up_proj", "down_proj"], # 这里分别应用了注意力机制中的Wq, Wk, Wv, Wo线性投影层,FFN中的线性层
    lora_alpha=8, # lora缩放因子,决定模型权重的更新程度,建议设置为r或r的倍数
    lora_dropout=0,
    bias="none", # 不为LoRA层添加偏置
    use_gradient_checkpointing="unsloth", # 是否设置梯度检查点,# 梯度检查点是一种以时间换空间的技术,可以减少内存占用,但会增加计算时间。
    random_state=3407, # 设置随机种子,保证实验可以浮现
    use_rslora=False, # 是否使用Rank-Stabilized LoRA(rslora)。rslora 是一种改进的 LoRA 方法,可以自动调整 `lora_alpha`。
    loftq_config=None # 是否使用QLoRA,即将LoRA与量化技术结合
)

3. 进行简单推理

# 将模型切换为推理模式,可以进行简单的对话
FastLanguageModel.for_inference(model)

question = "请介绍一下你自己!"
# 对输入进行分词
# 传入待分词的文本列表,最后返回一个PyTorch张量
input_ids = tokenizer([question], return_tensors="pt").to("cuda")
# input_ids返回token对应词表中的id,即将一个句子映射为一个token id序列
# attention_mask用于表示input_ids中哪些是为了填充序列长度而通过<pad>填充的token,1表示所有的 token 都是实际的词或标记,没有填充。
input_ids

# 调用模型生成答复
'''
是否使用缓存机制来加速生成过程。

设置为 True 表示使用缓存。

缓存机制会存储先前计算的键/值对(key/value pairs),避免重复计算,从而提高生成速度。

在自回归生成(逐个 token 生成)中,缓存机制非常有用。
'''
outputs_ids = model.generate(
    input_ids=input_ids.input_ids,
    max_new_tokens=1024,
    use_cache=True
)
# 模型的直接输出同样为token ids,需要通过tokenizer进行解码
outputs_ids

response = tokenizer.batch_decode(outputs_ids)
print(response[0])

可以在prompt中添加<think>标签对引导模型进行思考。

question = "你好,好久不见!"
# 更完善的prompt
prompt_style_chat = """请写出一个恰当的回答来完成当前对话任务。

### Instruction:
你是一名助人为乐的助手。

### Question:
{}

### Response:
"""

# 使用tokenizer处理prompt
input_ids = tokenizer([prompt_style_chat.format(questionm '')], return_tensors="pt").to("cuda")

outputs = model.generate(
    input_ids=input_ids.input_ids,
    use_cache=True,
    do_sample=True,          # 启用采样
    temperature=0.7,         # 较高的温度
    top_p=0.9,               # Top-p 采样
    repetition_penalty=1.2,  # 重复惩罚
    max_new_tokens=1024,     # 最大新token数量
)

response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(response)

未添加<think>标签,模型有概率不思考。

prompt_style_chat = """请写出一个恰当的回答来完成当前对话任务。

### Instruction:
你是一名助人为乐的助手。

### Question:
{}

### Response:
<think>{}
"""
question = "请你分析李朗笛和朗朗以及李云迪之间的关系"

# 使用tokenizer处理prompt
input_ids = tokenizer([prompt_style_chat.format(question, "")], return_tensors="pt").to("cuda")

outputs = model.generate(
    input_ids=input_ids.input_ids,
    use_cache=True,
    do_sample=True,          # 启用采样
    temperature=0.7,         # 较高的温度
    top_p=0.95,               # Top-p 采样
    repetition_penalty=1.2,  # 重复惩罚
    max_new_tokens=1024,     # 最大新token数量
)

response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
print(response)

 添加<think>标签作为引导,模型更容易进行思考。

微调实操

微调请重新开一个notebook,清空缓存,从头进行。

1. 倒入依赖

# 导入依赖
from modelscope.msdatasets import MsDataset # modelscope数据集类
from trl import SFTTrainer # 微调训练器配置类
from transformers import TrainingArguments # 微调参数配置类
from unsloth import FastLanguageModel, is_bfloat16_supported # 检查GPU是否支持bf16
import wandb # 微调数据可视化

 2. 定义模板

因为数据集是英文的,所以promt也采用英文,保证语言一致性。

# 1. 定义prompt模板
finetune_template = '''Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request. 
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.

### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Please answer the following medical question. 

### Question:
{}

### Response:
<think>
{}
</think>
{}'''

prompt_style_zh = '''以下是一个任务说明,配有提供更多背景信息的输入。
请写出一个恰当的回答来完成该任务。
在回答之前,请仔细思考问题,并按步骤进行推理,确保回答逻辑清晰且准确。

### Instruction:
您是一位具有高级临床推理、诊断和治疗规划知识的医学专家。
请回答以下医学问题。


### 问题:
{}

### 问题:
<think>{}'''

3. 加载模型

# 2. 加载模型
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./DeepSeek-R1-Distill-Qwen-7B",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=False
)
EOS_TOKEN = tokenizer.eos_token

4. 加载数据集

# 在模型微调时,给微调数据集加上 EOS_TOKEN 非常重要。它可以明确文本边界、保持训练目标一致性、控制生成过程、处理多轮对话,以及更好地利用 CoT 数据集。
EOS_TOKEN = tokenizer.eos_token

# 格式话训练数据
def formatting_prompts_func(examples):
    inputs = examples["Question"]
    cots = examples["Complex_CoT"]
    outputs = examples["Response"]
    texts = []
    for input, cot, output in zip(inputs, cots, outputs):
        text = finetune_template.format(input, cot, output) + EOS_TOKEN
        texts.append(text)
    return {
        "text": texts,
    }

ds = MsDataset.load('AI-ModelScope/medical-o1-reasoning-SFT', split = "train")
dataset = ds.map(formatting_prompts_func, batched = True,)
print(dataset["text"][0])

5. 配置微调模型

将LoRA模块加入模型,为微调做准备

'''
LoRA 的核心思想是,对于预训练模型的权重矩阵 W,不直接对其进行更新,
而是添加一个低秩分解矩阵 ΔW = A * B,
其中 A 和 B 是两个较小的矩阵。在微调过程中,只更新 A 和 B 的参数,而 W 的参数保持不变。
这样可以大大减少需要微调的参数数量,降低计算成本。
'''
model = FastLanguageModel.get_peft_model(
    model,
    r=16, # lora微调的秩  # 较小的 `r` 值会减少需要微调的参数数量,降低计算成本,但也可能降低模型的表达能力。# 较大的 `r` 值会增加参数数量,提高模型的表达能力,但也会增加计算成本。
         # 通常需要根据实际情况进行实验,选择合适的 `r` 值。一般来说,8、16、32、64 是常用的值。
    target_modules = ["q_proj", "k_proj", "v_proj", # 指定要应用 LoRA 的模块。这些模块通常是 Transformer 模型中的线性层。
                     "o_proj", "gate_proj", "up_proj", "down_proj"], # 这里分别应用了注意力机制中的Wq, Wk, Wv, Wo线性投影层,FFN中的线性层
    lora_alpha=16, # lora缩放因子,决定模型权重的更新程度,建议设置为r或r的倍数
    lora_dropout=0,
    bias="none", # 不为LoRA层添加偏置
    use_gradient_checkpointing="unsloth", # 是否设置梯度检查点,# 梯度检查点是一种以时间换空间的技术,可以减少内存占用,但会增加计算时间。
    random_state=3407, # 设置随机种子,保证实验可以浮现
    use_rslora=False, # 是否使用Rank-Stabilized LoRA(rslora)。rslora 是一种改进的 LoRA 方法,可以自动调整 `lora_alpha`。
    loftq_config=None # 是否使用QLoRA,即将LoRA与量化技术结合
)

6. 配置微调参数

# 5. 配置微调参数
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text", # 数据集中包含文本的字段的名称。
    # dataset_text_field="text", # 说明text列对应的是微调数据集
    max_seq_length=2048, # 模型能处理的最长序列
    dataset_num_proc=2, # 用于预处理数据的进程数。
    args=TrainingArguments(
        per_device_train_batch_size=2, # mini-batch-size
        gradient_accumulation_steps=4, # 梯度累积,用于模型batch_size=2*4=8的情况,模型实际上经过 2 * 4 = 8 个batch之后才会更新参数(一个step),能缓解GPU无法放下大batch的问题
        num_train_epochs=3, # 训练轮数
        # max_steps = 60 # 如果要进行迅速严重微调可行性,可以只训练60个steps,训练的总步数(参数更新次数)。
        warmup_steps=5, # 模型热身步数,学习率会从 0 逐渐增加到设定的学习率。
        lr_scheduler_type="linear", # 学习率调度器类型。这里使用线性调度器,学习率会线性下降。
        learning_rate=2e-4, # 学习率
        fp16=not is_bfloat16_supported(), # 是否使用 FP16(16 位浮点数)混合精度训练。如果 GPU 不支持 bfloat16,则使用 fp16。
        bf16=is_bfloat16_supported(),
        logging_steps=10, # 多少个step打印一次信息
        optim="adamw_8bit", # 指定优化器
        weight_decay=0.01, # 权重衰退
        seed=3407, # 随机种子,保证结果可以复现
        output_dir="outputs" # 保存训练结果(模型、日志等)的目录。
    )
)

7. 进行微调

# 6. 进行微调
wandb.init()
trainer_stats = trainer.train()

看到如下输出即表示微调正在运行中。 

8. 将LoRA权重与原始矩阵合并,保存微调后的模型

# LoRA微调完成后,保存微调模型并合并矩阵
new_model_local = "DeepSeek-R1-Qwen-7B-Medical-Full"  # 定义一个字符串变量,表示保存模型的本地路径。
model.save_pretrained(new_model_local)          # 保存微调后的模型(包括 LoRA 权重)。
tokenizer.save_pretrained(new_model_local)      # 保存分词器。
model.save_pretrained_merged(new_model_local, tokenizer, save_method="merged_16bit")  # 合并 LoRA 权重到基础模型中,并保存合并后的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2319068.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AI News | 20250320】每日AI进展

AI Repos 1、servers 该仓库提供详细入门指南&#xff0c;用户可通过简单步骤连接Claude客户端&#xff0c;快速使用所有服务器功能。此项目由Anthropic管理&#xff0c;展示MCP的多样性与扩展性&#xff0c;助力开发者为大语言模型提供安全、可控的工具与数据访问。 2、awe…

从零开始实现 C++ TinyWebServer 阻塞队列 BlockQueue类详解

文章目录 阻塞队列是什么&#xff1f;为什么需要阻塞队列&#xff1f;BlockQueue 成员变量实现 push() 函数实现 pop() 函数实现 close() 函数BlockQueue 代码BlockQueue 测试 从零开始实现 C TinyWebServer 项目总览 项目源码 阻塞队列是什么&#xff1f; 阻塞队列是一种线程…

Linux驱动开发基础(can)

目录 1.can的介绍 2.can的硬件连接 2.1 CPU自带can控制器 2.2 CPU没有can控制器 3.电气属性 4.can的特点 5.can协议 5.1 can的种类 5.2 数据帧 5.2.1 标准数据帧格式 5.3.1 扩展数据帧格式 5.3 遥控帧 5.4 错误帧 5.5 过载帧 5.6 帧间隔 5.7 位填充 5.8 位时…

leetcode热题100道——字母异位词分组

给你一个字符串数组&#xff0c;请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 示例 1: 输入: strs ["eat", "tea", "tan", "ate", "nat", &…

MCU-芯片时钟与总线和定时器关系,举例QSPI

时钟源&#xff1a; 时钟源为系统时钟提供原始频率信号&#xff0c;系统时钟则通过&#xff08;分频、倍频、选择器&#xff09;成为整个芯片的“主时钟”&#xff0c;驱动 CPU 内核、总线&#xff08;AHB、APB&#xff09;及外设的运行。 内部时钟源&#xff1a; HSI&#x…

技术分享 | MySQL内存使用率高问题排查

本文为墨天轮数据库管理服务团队第51期技术分享&#xff0c;内容原创&#xff0c;如需转载请联系小墨&#xff08;VX&#xff1a;modb666&#xff09;并注明来源。 一、问题现象 问题实例mysql进程实际内存使用率过高 二、问题排查 2.1 参数检查 mysql版本 &#xff1a;8.0.…

分享一个精灵图生成和拆分的实现

概述 精灵图&#xff08;Sprite&#xff09;是一种将多个小图像合并到单个图像文件中的技术&#xff0c;广泛应用于网页开发、游戏开发和UI设计中。在MapboxGL中&#xff0c;跟之配套的还有一个json文件用来记录图标的大小和位置。本文分享基于Node和sharp库实现精灵图的合并与…

函数:形参和实参

在函数的使用过程中分为实参和形参&#xff0c;实参是主函数实际调用的值而形参则是给实参调用的值&#xff0c;如果函数没被调用则函式不会向内存申请空间&#xff0c;先用一段代码演示 形参&#xff1a; int test(int x ,int y ) {int z 0;z x y;return z; } 为何会叫做…

【C#知识点详解】ExcelDataReader介绍

今天来给大家介绍一下ExcelDataReader&#xff0c;ExcelDataReader是一个轻量级的可快速读取Excel文件中数据的工具。话不多说直接开始。 ExcelDataReader简介 ExcelDataReader支持.xlsx、.xlsb、.xls、.csv格式文件的读取&#xff0c;版本基本在2007及以上版本&#xff0c;支…

《视觉SLAM十四讲》ch13 设计SLAM系统 相机轨迹实现

前言 相信大家在slam学习中&#xff0c;一定会遇到slam系统的性能评估问题。虽然有EVO这样的开源评估工具&#xff0c;我们也需要自己了解系统生成的trajectory.txt的含义&#xff0c;方便我们更好的理解相机的运行跟踪过程。 项目配置如下&#xff1a; 数据解读&#xff1a; …

在类Unix终端中如何实现快速进入新建目录

&#x1f6aa; 前言 相信喜欢使用终端工作的小伙伴或多或少会被一个小地方给膈应&#xff0c;那就是每次想要新建一个文件夹并且进入之&#xff0c;那么就需要两条指令&#xff1a;mkdir DIR和cd DIR&#xff0c;有些人可能要杠了&#xff0c;我一条指令也能&#xff0c;mkdir…

TG电报群管理机器人定制开发的重要性

在Telegram&#xff08;电报&#xff09;用户突破20亿、中文社群规模持续扩张的背景下&#xff0c;定制化群管理机器人的开发已成为社群运营的战略刚需。这种技术工具不仅解决了海量用户管理的效率难题&#xff0c;更通过智能化功能重构了数字社群的治理范式。本文从管理效能、…

VNA操作使用学习-01 界面说明

以我手里面的liteVNA为例。也可以参考其他的nanoVNA的操作说明。我先了解一下具体的菜单意思。 今天我想做一个天调&#xff0c;居然发现我连一颗基本的50欧姆插件电阻和50欧姆的smt电阻的幅频特性都没有去测试过&#xff0c;那买来这个nva有什么用途呢&#xff0c;束之高阁求…

耘想Docker版Linux NAS的安装说明

耘想LinNAS&#xff08;Linux NAS&#xff09;可以通过Docker部署&#xff0c;支持x86和arm64两种硬件架构。下面讲解LinNAS的部署过程。 1. 安装Docker CentOS系统&#xff1a;yum install docker –y Ubuntu系统&#xff1a;apt install docker.io –y 2. 下载LinNas镜像…

OpenCV图像拼接(4)图像拼接模块的一个匹配器类cv::detail::BestOf2NearestRangeMatcher

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::detail::BestOf2NearestRangeMatcher 是 OpenCV 库中用于图像拼接模块的一个匹配器类&#xff0c;专门用于寻找两幅图像之间的最佳特征点匹配…

不用 Tomcat?SpringBoot 项目用啥代替?

在SpringBoot框架中&#xff0c;我们使用最多的是Tomcat&#xff0c;这是SpringBoot默认的容器技术&#xff0c;而且是内嵌式的Tomcat。 同时&#xff0c;SpringBoot也支持Undertow容器&#xff0c;我们可以很方便的用Undertow替换Tomcat&#xff0c;而Undertow的性能和内存使…

Zabbix安装(保姆级教程)

Zabbix 是一款开源的企业级监控解决方案&#xff0c;能够监控网络的多个参数以及服务器、虚拟机、应用程序、服务、数据库、网站和云的健康状况和完整性。它提供了灵活的通知机制&#xff0c;允许用户为几乎任何事件配置基于电子邮件的告警&#xff0c;从而能够快速响应服务器问…

鸿蒙开发真机调试:无线调试和USB调试

前言 在鸿蒙开发的旅程中&#xff0c;真机调试堪称至关重要的环节&#xff0c;其意义不容小觑。虽说模拟器能够为我们提供初步的测试环境&#xff0c;方便我们在开发过程中快速预览应用的基本效果&#xff0c;但它与真机环境相比&#xff0c;仍存在诸多差异。就好比在模拟器中…

工厂函数详解:概念、目的与作用

一、什么是工厂函数&#xff1f; 工厂函数&#xff08;Factory Function&#xff09;是一种设计模式&#xff0c;其核心是通过一个函数来 创建并返回对象&#xff0c;而不是直接使用 new 或构造函数实例化对象。它封装了对象的创建过程&#xff0c;使代码更灵活、可维护。 二、…

Python简单爬虫实践案例

学习目标 能够知道Web开发流程 能够掌握FastAPI实现访问多个指定网页 知道通过requests模块爬取图片 知道通过requests模块爬取GDP数据 能够用pyecharts实现饼图 能够知道logging日志的使用 一、基于FastAPI之Web站点开发 1、基于FastAPI搭建Web服务器 # 导入FastAPI模…