深入解析 Loss 减少方式:mean和sum的区别及其在大语言模型中的应用 (中英双语)

news2024/12/26 16:21:13

深入解析 Loss 减少方式:meansum 的区别及其在大语言模型中的应用

在训练大语言模型(Large Language Models, LLM)时,损失函数(Loss Function)的处理方式对模型的性能和优化过程有显著影响。本文以 reduce_loss 参数为例,详细探讨 meansum 两种方式的定义、适用场景及其对对话模型性能的潜在提升原因,并通过代码实例加深理解。


1. 什么是 reduce_loss

reduce_loss 决定了在每个 batch 中,如何对 token-level 的损失进行归一化或累加处理。常见的选项是:

  • mean: 取每个 token 损失的平均值。
  • sum: 将每个 token 损失直接累加。

参数定义示例(在代码中通过 dataclass 定义):参考来源:https://github.com/allenai/open-instruct

from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    reduce_loss: str = field(
        default="mean",
        metadata={
            "help": (
                "How to reduce loss over tokens. Options are 'mean' or 'sum'."
                "Using 'sum' can improve chat model performance."
            )
        },
    )

2. meansum 的定义

2.1 mean 模式
  • 定义:将 batch 中所有 token 的损失值取平均。
  • 公式
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean=Ni=1NLossi
    其中 ( N N N) 是当前 batch 中的 token 总数。
  • 特性:每个 token 的损失对最终的 loss 贡献相等,损失值与 batch 中的 token 数无关。
2.2 sum 模式
  • 定义:将 batch 中所有 token 的损失值直接累加。
  • 公式
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum=i=1NLossi
  • 特性:长序列(更多 token)的损失对总 loss 的贡献更大,损失值直接与 token 数成正比。

3. meansum 的区别

模式特点优点缺点
mean损失对 token 数归一化,独立于 batch size。稳定性强,适用于 token 数差异大的批次。长序列与短序列对损失的贡献相同,可能弱化长序列的重要性。
sum损失值与 token 总数成正比,长序列贡献更大。在注重长序列表现的任务中效果更好(如对话生成)。损失值随 batch size 变化波动,需要动态调整学习率。

4. 适用场景分析

4.1 mean
  • 适用任务:大多数语言建模任务,如 GPT 或 BERT 的预训练。
  • 适用场景:当训练数据中序列长度差异较大时,mean 可以避免因长序列的损失值过大而导致梯度更新不均衡。
4.2 sum
  • 适用任务:对长序列表现要求较高的任务,如对话生成(Chat Models)和长文本生成。
  • 适用场景:长序列的损失占比更高,从而使优化过程更加关注全局上下文的建模。

5. 为什么 sum 能提升对话模型性能?

对话模型(Chat Models)的训练中,长序列往往包含丰富的上下文信息,而短序列则可能无法体现模型的上下文理解能力。在 sum 模式下:

  1. 长序列的重要性增加:长序列的损失对总损失的贡献更大,这促使模型更关注上下文的建模。
  2. 对全局一致性更敏感sum 模式下,模型的优化方向更倾向于全序列的一致性,特别适合需要长距离依赖的任务。

示例
假设一个 batch 包含以下两个样本:

  • 样本 A: 长度为 10,损失总和为 5。
  • 样本 B: 长度为 50,损失总和为 25。

计算损失贡献:

  • mean 模式
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean=10+505+25=0.5
    样本 A 和 B 的贡献权重相同。
  • sum 模式
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum=5+25=30
    样本 B 的贡献权重显著增加,优化更关注长序列。

6. 实战代码

以下是一个完整的训练脚本,展示如何在 Hugging Face 的 transformers 框架中使用 reduce_loss 参数。

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

# 模型和数据集
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)

# 训练设置
reduce_loss = "sum"  # 改为 "mean" 可对比效果
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 训练循环
for epoch in range(2):
    for batch in train_loader:
        inputs = torch.tensor(batch["input_ids"]).to(device)
        labels = inputs.clone()
        outputs = model(inputs, labels=labels)

        if reduce_loss == "sum":
            loss = outputs.loss.sum()
        else:  # 默认 "mean"
            loss = outputs.loss.mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. 注意事项与优化建议

  1. 动态调整学习率

    • 使用 sum 时,由于损失值放大,建议适配学习率,如降低到 mean 模式的 ( 1 / N 1/N 1/N )。
    • 配合学习率调度器(如 linear)优化训练。
  2. 对长短序列的平衡

    • 若长序列权重过大导致模型性能退化,可结合 curriculum learning 或混合训练策略(如对长短序列按比例采样)。
  3. 性能评估

    • 在验证集上,关注长序列和短序列的生成性能对比。

8. 总结

reduce_loss 的选择对模型性能有直接影响:

  • mean 更通用,适合大多数语言建模任务。
  • sum 在对话生成等长序列敏感任务中表现更优。

希望本文能为 LLM 研究人员提供思路和参考,在具体任务中灵活选择合适的损失归一化方式,从而提升模型性能。

Understanding the Difference Between mean and sum Loss Reduction in LLM Training

When training large language models (LLMs), the way token-level loss is reduced across a batch can significantly impact optimization and model performance. This article delves into the reduce_loss parameter, exploring the differences between mean and sum reduction modes, their definitions, use cases, and why sum might improve the performance of chat-oriented models. Practical code examples are also provided for clarity.


1. What is reduce_loss?

The reduce_loss parameter determines how the token-level loss values in a batch are aggregated. The two most common options are:

  • mean: Averages the loss over all tokens in a batch.
  • sum: Sums the loss of all tokens in a batch.

Example definition (from the codebase using Python dataclass):

from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    reduce_loss: str = field(
        default="mean",
        metadata={
            "help": (
                "How to reduce loss over tokens. Options are 'mean' or 'sum'."
                "Using 'sum' can improve chat model performance."
            )
        },
    )

2. Definitions of mean and sum

2.1 mean
  • Definition: Averages the loss across all tokens in a batch.
  • Formula:
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean=Ni=1NLossi
    where ( N N N ) is the total number of tokens in the batch.
  • Characteristics: The contribution of each token to the final loss is normalized, making the loss independent of the batch’s token count.
2.2 sum
  • Definition: Sums up the loss across all tokens in a batch.
  • Formula:
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum=i=1NLossi
  • Characteristics: The total loss is proportional to the number of tokens, giving longer sequences more weight in the optimization process.

3. Key Differences Between mean and sum

Reduction ModeCharacteristicsAdvantagesDisadvantages
meanNormalizes the loss by token count.Stable and robust for datasets with variable-length sequences.Long sequences are underweighted relative to short ones.
sumLoss scales with the number of tokens.Places greater emphasis on longer sequences, improving performance in tasks requiring context modeling.Loss values vary with batch size, necessitating dynamic learning rate adjustment.

4. Use Cases for mean and sum

4.1 mean
  • Best Suited For: Pretraining or general language modeling tasks like GPT or BERT.
  • Scenario: When the dataset contains sequences of widely varying lengths, mean ensures that longer sequences do not disproportionately influence gradient updates.
4.2 sum
  • Best Suited For: Tasks requiring high performance on long sequences, such as dialogue generation or document-level text generation.
  • Scenario: Encourages the model to prioritize sequences with richer contexts, as their loss contributes more to the overall optimization.

5. Why Does sum Improve Chat Model Performance?

In chat-oriented models, sequences are typically longer and require the model to understand and generate coherent responses over extended contexts. Using sum mode:

  1. Enhances Long Sequence Weighting: Longer sequences contribute more to the total loss, emphasizing the importance of context modeling.
  2. Encourages Global Consistency: By assigning more weight to longer contexts, the model better captures dependencies across the entire sequence.
  3. Balances Token Importance: Since chat models are often evaluated on dialogue-level coherence, sum ensures that tokens from the context and the response are proportionally weighted.

Example:
Consider a batch with two samples:

  • Sample A: Sequence length = 10, loss = 5.
  • Sample B: Sequence length = 50, loss = 25.

Loss calculations:

  • mean mode:
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean=10+505+25=0.5
    Both samples contribute equally to the loss.
  • sum mode:
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum=5+25=30
    Sample B contributes much more to the total loss, focusing the optimization on longer contexts.

6. Practical Implementation

Here’s a practical training script that demonstrates the use of reduce_loss in both modes.

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

# Model and dataset
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)

# Training setup
reduce_loss = "sum"  # Change to "mean" to compare effects
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
for epoch in range(2):
    for batch in train_loader:
        inputs = torch.tensor(batch["input_ids"]).to(device)
        labels = inputs.clone()
        outputs = model(inputs, labels=labels)

        if reduce_loss == "sum":
            loss = outputs.loss.sum()
        else:  # Default: "mean"
            loss = outputs.loss.mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. Practical Considerations

  1. Learning Rate Adjustment:

    • When using sum, the loss magnitude increases with batch size, so you may need to adjust the learning rate (e.g., scale it down by ( 1 / N 1/N 1/N )).
  2. Balancing Long and Short Sequences:

    • Overweighting long sequences can sometimes harm generalization. Using curriculum learning or sampling strategies (e.g., proportional sampling) can help mitigate this.
  3. Validation:

    • Evaluate model performance on both short and long sequences to confirm improvements in the intended metrics.

8. Conclusion

The choice between mean and sum loss reduction modes depends on the specific task and dataset:

  • Use mean for general-purpose language modeling tasks where sequence lengths vary significantly.
  • Use sum for tasks that prioritize long-sequence performance, such as chat models or long-text generation.

Understanding and experimenting with these settings can lead to better-optimized models, particularly in the nuanced field of LLM fine-tuning.

后记

2024年12月3日16点04分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2254099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL索引(三):选错索引

优化器选择索引的目的,是找到一个最优的执行方案,并用最小的代价去执行语句。 思考 假设有表结构: -- T表结构: CREATE TABLE t (id int(11) NOT NULL,a int(11) DEFAULT NULL,b int(11) DEFAULT NULL,PRIMARY KEY (id),KEY a (…

区块链学习笔记(2)--区块链的交易模型part1

模型基础 区块链的tx分为两种模型,分别是比特币为代表的UTXO(Unspent Transaction Output)模型,和以太坊为代表的Account模型。前者适用于货币记账,后者适用于链上应用。 UTXO模型 类似于现金的交易模型 一个tx包含…

Redis 基础、Redis 应用

Redis 基础 什么是 Redis? Redis (REmote DIctionary Server)是一个基于 C 语言开发的开源 NoSQL 数据库(BSD 许可)。与传统数据库不同的是,Redis 的数据是保存在内存中的(内存数据库&#xf…

php7.4安装pg扩展-contos7

今天接到一个需求,就是需要用thinkphp6链接pg(postgresql)数据库。废话不多说,直接上操作步骤 一、安装依赖 yum install -y sqlite-devel libxml2 libxml2-devel openssl openssl-devel bzip2 bzip2-devel libcurl libcurl-devel libjpeg libjpeg-dev…

Linux中的常用基本指令(下)

Linux常用基本指令 Linux中的基本指令12.head指令13.tail指令简单解释重定向与管道(重要) 14.date指令(时间相关的指令)15.cal指令(不重要)16.find指令(灰常重要)17.grep指令(重要)18.which指令和alias指令19.zip/unzip指令:20.tar指令(重要&…

Android 还在使用LogCat打日志?XLog框架;日志打印到控制台,打印到文件中。

目录: 为什么要打印日志?XLog是什么XLog如何使用 一、为什么要打印日志? 日志是我们系统出现错误时,最快速有效的定位工具,没有日志给出的错误信息,遇到报错你就会一脸懵逼;而且日志还可以用来…

zabbix“专家坐诊”第266期问答

问题一 Q:zabbix编译升级主要工作是不是将PHP,nginx,zabbix都重新编译安装一遍,细节的先不说 A:升级zabbix就可以 Q:这个OID是哪个OID A:mib文件里面有个snmp oid的值 那个就是oid。https://blog.csdn.net/qq_508853…

第八课 Unity编辑器创建的资源优化_特效篇(Particle System)详解

无论是CPU还是GPU,粒子系统对其的影响面都是不容小觑的。随着项目的重度化和3A化,玩家的口味变挑剔了、游戏玩法复杂度变高了、画面的特效表现变复杂了......所以我们还是更加谨慎地对待粒子系统。 特效(Particle System) 游戏效…

王道考研编程题总结

我还在完善中,边复习边完善(这个只是根据我自身总结的) 一、 线性表 1. 结构体 #define MaxSize 40 typedef struct{ElemType data[MaxSize];int length; }SqList 2. 编程题 1. 删除最小值 题意 :从顺序表中删除…

ubuntu20.04安装OpenPcdet,CUDA版本11.8,显卡4090

本文参考这2篇文章的内容:https://blog.csdn.net/jin15203846657/article/details/122735375#comments_25352667 https://zhuanlan.zhihu.com/p/642158810 记录了自己安装OpenPcdet的过程。 OpenPcdet的安装需要cuda和pytorch版本严格关联。本例的CUDA版本&#xf…

初识EasyFramework

一、获取EF Git地址:https://github.com/HiWenHao/EFrameworkGitee地址:https://gitee.com/wang_xiaoheiiii/EFramework视频合集:EasyFramework介绍_哔哩哔哩_bilibiliQQ群: 711540505 二、 下载并初步了解 1. 下载完成后,可以看…

爬虫获取的数据如何用于市场分析

目录 一、网络爬虫基础 HTML解析器 API接口 数据库抓取 二、数据预处理 数据清洗 数据转换 数据整合 三、市场分析应用 消费者行为分析 竞争对手分析 市场趋势预测 四、案例分析 数据获取 数据预处理 市场分析 总结 在当今数据驱动的商业环境中,市…

C++小碗菜之二:软件单元测试

“没有测试的代码重构不能称之为重构,它仅仅是垃圾代码的到处移动” ——Corey Haines 目录 前言 什么是单元测试? 单元测试的组成 单元测试的命名 单元测试的独立性 Google Test 单元测试的环境配置与使用 1. Ubuntu下安装 Google Test 2. 编写…

Kubernetes架构原则和对象设计

云原生学习路线导航页(持续更新中) 快捷链接 Kubernetes常见问题解答 本文从 Google Borg系统的架构设计开始,深入讲解Kubernetes架构及组件的基本原理 1.什么是云计算 1.1.传统行业应用 假设有10台服务器,两个应用。小规模管…

力扣-图论-1【算法学习day.51】

前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向和记录学习过程(例如想要掌握基础用法,该刷哪些题?)我的解析也不会做的非常详细,只会提供思路和一些关键点,力扣上的大佬们的题解质量是非…

学习笔记056——Docker日志的清理问题

文章目录 Docker日志的清理问题1、Docke日志所在位置2、日志清理 Docker日志的清理问题 Ubuntu上部署Docker,运行一段时间后,会累计很多的日志量。 如果不及时处理,会占用系统空间,影响系统性能。 如何处理日志累计过大的问题&…

Python3:Pytest框架parametrize报错in “parametrize“ the number of names (4)

Python3:Pytest框架parametrize报错in “parametrize“ the number of names (4) 排查原因:是pytest入参时,需要4个参数,但是提供了3个参数 test_tenant_list:- ["http://xx:8081/scheduler/v1/tenancy/list",{"co…

Linux 35.6 + JetPack v5.1.4之RTP实时视频Python框架

Linux 35.6 JetPack v5.1.4之RTP实时视频Python框架 1. 源由2. 思路3. 方法论3.1 扩展思考 - 慎谋而后定3.2 扩展思考 - 拒绝拖延或犹豫3.3 扩展思考 - 哲学思考3.4 逻辑实操 - 方法论 4 准备5. 分析5.1 gst-launch-1.05.1.1 xvimagesink5.1.2 nv3dsink5.1.3 nv3dsink sync05…

GIt (一) Git的安装,项目搭建,远程仓库,分支

文章目录 一、 版本控制1.1 集中式版本控制1.2 分布式版本控制 二、 Git的安装及配置2.1 安装2.2 Git的配置2.2 查看配置 三、 Git基本理论3.1 工作区域3.2 文件状态 四、Git项目的搭建与操作4.1 初始化Git仓库4.2 常见的操作4.2.1 文件添加到暂存区4.2.2 文件提交更新4.2.3 查…

iview upload clearFiles清除回显视图

iview upload 上传完文件之后清除内容&#xff0c;打开会回显视图&#xff0c;清除不掉 关闭弹框时主动清除回显内容即可this.$refs.uploads.clearFiles() <FormItem label"上传附件:" :label-width"formNameWidth"><Upload action"/fms/ap…