【Pytorch】大语言模型中的CrossEntropyLoss

news2024/9/20 20:53:04

文章目录

  • 前言
  • 什么是CrossEntropyLoss
  • 语言模型中的CrossEntropyLoss
    • 计算loss的前期准备
    • CrossEntropyLoss的输入
    • CrossEntropyLoss的输出
  • 额外说明

前言

在大语言模型时代,我们常常使用交叉熵损失函数来计算loss,因此,理解该loss的计算流程有助于帮助我们对训练过程有更清晰的认知。本文从以下几个角度介绍nn.CrossEntropyLoss()

  • 使用该函数的前期准备:如何组织函数的输入(logits & labels)
  • 该函数流程
  • 常用参数
  • 该文章内容仅为个人理解,如有误解,欢迎讨论

什么是CrossEntropyLoss

这部分并不是本文的重点,我们仅介绍在语言模型的训练过程中,如何利用该loss

  • 相关信息可见:本人博客
  • 以及官网:CrossEntropyLoss官网

语言模型中的CrossEntropyLoss

计算loss的前期准备

huggingface-transformers源码中,我们在语言模型的forward中总是能看到这样一段函数。我们以LlamaForCausalLM为例:Llama源码

if labels is not None:
  # Shift so that tokens < n predict n
  shift_logits = logits[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
  # Flatten the tokens
  loss_fct = CrossEntropyLoss()
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
  shift_labels = shift_labels.view(-1)
  # Enable model parallelism
  shift_labels = shift_labels.to(shift_logits.device)
  loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
  output = (logits,) + outputs[1:]
  return (loss,) + output if loss is not None else output

对于Decoder-only模型,在训练时,我们的目标是next token prediction,任务流程如下

  • 假定我们是常规的问答任务,问题是“where is the capital of China“,label为“The capital is Beijing”。该任务的目标为,当输入为“where is the capital of China“时,

  • 我们对question和label进行拼接和tokenize化,一般转化结果 (tokenize忽略) 为:< bos > where is the capital of China < sep > The capital is Beijing < eos >

    • < bos>为句子开头的标志
    • < sep>用于分隔question和label,本质作用是,当模型看到时就知道:问题结束了,下一个token要输出答案了
    • < eos>为生成结束的标志
    • 假定每个词算一个token (忽略空格),那么输入一共有13个token
  • 这时我们将整个序列输入到模型中,模型在每个token的位置都生成一个向量,我们利用lm_head将最后一层的hidden state转化成词表大小的向量logits,用于后续利用Softmax确定每个token的概率

  • 现在模型有了输出logits,怎么计算loss?

    • 对比labels和logits之间的差异来计算loss

    • 现在一共有13个token,生成了13个logits,每个logits都是用于生成next token的。那么很直接的,我们来对比该logits生成的next token准不准就好了

      • 输入:< bos> where is the capital of China < sep> The capital is Beijing < eos>

      • 对比情况为:< sep>->The, The->capital, …, is->Beijing, Beijing->< eos>

        • < sep>对应位置要生成The,…, Beijing对应位置要输出< eos>
      • 我们可以将输入右移一位作为labels: where is the capital of China < sep> The capital is Beijing

        • 可以看到,对于输入来说, < eos>位置没有对应的需要生成的token,因此我们去掉该token
        • 对于labels,< bos>不需要生成,因此我们去掉该token
      • 因此,我们在计算loss时,对logits去尾,labels是输入掐头且右移一位

      • 在代码中对应

          shift_logits = logits[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()
        

CrossEntropyLoss的输入

此时还不能直接将shift_logitsshift_labels进行对比,来计算loss。因为我们上面的操作只是为了<sep> The capital is BeijingThe capital is Beijing <eos>中的token能一一对应起来,对于其他部分生成的token,我们并没有要求(因为不是answer,不需要生成)

  • CrossEntropyLoss函数中有一个参数为ignore_idx默认值为-100。labels值设置为-100的位置不会计算loss
  • 因此我们将除了需要计算loss的位置 (最后5个位置)的labels都设置为-100
  • 最终,需要输入到CrossEntropyLoss中的inputs和labels为
    • inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing]对应的logits
      • 注意:不需要进行Softmax,直接传logits即可,函数内部有更稳定的Softmax计算方式
    • labels为: [-100, -100, -100, -100, -100, -100, -100, The, capital, is, Beijing, < eos>]
    • 我们在训练时,构造输入和labels要注意构造为这种形式

CrossEntropyLoss的输出

默认情况下,输出为mean,即各个token计算得到loss的平均值(在token-level上平均,分母是token的个数)

import torch
import torch.nn as nn

# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])

# 标签,其中第二个样本的标签为 ignore_index (-100)
labels = torch.tensor([0, -100, 2])

# 定义 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)

print(f"Loss: {loss}")
>>> Loss: 0.51058030128479
  • 常用参数:

    • reduction:控制loss的输出形式,共三种'none', 'mean', 'sum',默认为'mean'

      • mean: 每个token计算得到的loss的平均值

      • none: 直接返回每个token计算得到的loss

        • 例子:

          import torch
          import torch.nn as nn
          
          # 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
          logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])
          
          # 标签,其中第二个样本的标签为 ignore_index (-100)
          labels = torch.tensor([0, -100, 2])
          
          # 定义 CrossEntropyLoss
          criterion = nn.CrossEntropyLoss(reduction='none')
          
          # 计算损失
          loss = criterion(logits, labels)
          
          print(f"Loss: {loss}")
          >>> Loss: tensor([0.4170, 0.0000, 0.6041])
          
      • sum: 所有token对应loss求和

额外说明

对最上面的代码补充说明

  shift_logits = shift_logits.view(-1, self.config.vocab_size)
  shift_labels = shift_labels.view(-1)
  • 训练数据往往是按batch组织的,shape为(batch_size, seq_len, vocab_size)
  • 我们将所有batch的token压缩为一个序列,计算整个序列的loss,这样比较方便

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2150049.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ADB 安装教程:如何在 Windows、macOS 和 Linux 上安装 Android Debug Bridge

目录 一、ADB 介绍 二、Windows 系统安装 ADB 1. 下载 ADB 2. 解压文件 3. 验证 ADB 安装 4. 配置环境变量 5. 验证全局 ADB 使用 三、macOS 系统安装 ADB 1. 下载 ADB 2. 解压文件 3. 配置环境变量 4. 验证 ADB 安装 四、Linux 系统安装 ADB 1. 使用包管理器安装…

Pandas和matplotlib实现同期天气温度对比

目录 1、下载近两年的天气Excel数据 2、pandas加载Excel 3、将时间作为索引 4、按日计算最值、均值 5、选取近两年同期温度数据 6、同期温度曲线对比,共享y轴 1、下载近两年的天气Excel数据 一个免费的天气数据下载网址:METAR北京(机场)历史天气 (rp5.ru) 选择”北京天…

20240921 每日AI必读资讯

AI、悟空、西湖文创集盒……2024云栖大会有超多硬核科技&#xff01; - 9月19日&#xff0c;一年一度的阿里云栖大会拉开帷幕 - 阿里现任掌舵者吴泳铭、CTO周靖人携手大模型领域当红炸子鸡月之暗面CEO杨植麟、小鹏汽车CEO何小鹏等一众明星企业创始人给业界带来了一场久违的国…

《 LiteFlow 规则引擎(1) - 入门篇》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

【RabbitMQ】应用

RabbitMQ 应用 1. 七种⼯作模式介绍1.1 Simple(简单模式)1.2 Work Queue(⼯作队列)1.3 Publish/Subscribe(发布/订阅)概念介绍Publish/Subscribe模式 1.4 Routing(路由模式)1.5 Topics(通配符模式)1.6 RPC(RPC通信)1.7 Publisher Confirms(发布确认) 2. ⼯作模式的使⽤案例2.1 …

Java【代码 18】处理Word文档里的Excel表格数据(源码分享)

处理Word文档里的Excel表格数据 1.原始数据2.处理程序2.1 识别替换表格表头2.2 处理多余的换行符2.3 处理后的结果 3.总结 1.原始数据 Word 文档里的 Excel 表格数据&#xff0c;以下仅为示例数据&#xff1a; 读取后的字符串数据为&#xff1a; "姓名\r\n身份证号\r\n手…

【计网】从零开始使用TCP进行socket编程 ---服务端业务模拟Xshell

最糟糕的情况&#xff0c; 不是你出了错&#xff0c; 而是你没有面对出错的勇气。 从零开始使用TCP进行socket编程 1 通信过程的多版本实现1.1 多进程版本1.2 多线程版本 2 服务端业务模拟Xshell2.1 整体框架设计2.2 Command类设计 1 通信过程的多版本实现 在前一篇的文章…

鸿蒙手势交互(三:组合手势)

三、组合手势 由多种单一手势组合而成&#xff0c;通过在GestureGroup中使用不同的GestureMode来声明该组合手势的类型&#xff0c;支持顺序识别、并行识别和互斥识别三种类型。 GestureGroup(mode:GestureMode, gesture:GestureType[]) //- mode&#xff1a;为GestureMode枚…

美元降息,对普通人有哪些影响?

美元降息&#xff0c;对普通人有哪些影响&#xff1f; 美元降息了。很多朋友都说我又不炒股&#xff0c;我手里又没有美金&#xff0c;美元跟我有啥关系啊&#xff1f;那我们就来聊聊美元降息&#xff0c;对我们国内经济到底有哪些影响&#xff1f;你再来看看跟你有没有关系&a…

计算机毕业设计 美发管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

MySQL高阶1890-2020年最后一次登录

目录 题目 准备数据 分析数据 题目 编写解决方案以获取在 2020 年登录过的所有用户的本年度 最后一次 登录时间。结果集 不 包含 2020 年没有登录过的用户。 返回的结果集可以按 任意顺序 排列。 准备数据 Create table If Not Exists Logins (user_id int, time_stamp …

数据库-约束与多表查询

1.约束 例子&#xff1a; 外键约束 例子&#xff1a; 2.多表查询 多表关系 概述 内连接 外连接 自连接 联合查询 子查询 介绍 标量子查询 仅有一个值 列子查询 行子查询 表子查询 练习

【应用开发三】 input子系统介绍

文章目录 1 名词解释2 输入设备编程框架2.1 input子系统2.2 读取数据流程2.3 input_event结构体2.3.1 type&#xff08;哪类事件&#xff09;2.2 code&#xff08;具体事件&#xff09;2.3 value&#xff08;数值&#xff09; 2.4 数据同步2.5 读取start input_event数据 1 名词…

微信小程序如何引入第三方插件

前言 微信的文档不行&#xff0c;我这个&#xff0c;行 如何找到插件管理的页面 扫码登录微信小程序的后台设置页面&#xff0c;点击小程序信息的查看详情&#xff0c;然后点第三方设置 修改app.json 在插件管理的页面添加好要用的插件之后&#xff0c;在插件的详情页面找到…

C++学习指南(六)----list

欢迎来到繁星的CSDN。本期内容主要包括&#xff0c;list的介绍、使用以及与vector的优缺点。 一、什么是list 在先前的C语言学习中&#xff0c;我们接触到了顺序表和链表&#xff0c;而在C中&#xff0c;这正好对应了vector&#xff08;动态增长顺序表&#xff09;和l…

机器学习(西瓜书)第 10 章 降维与度量学习

10.1 k近邻学习kNN k 近邻(k-Nearest Neighbor,简称kNN)学习是一种常用的监督学习方法,其工作机制非常简单&#xff1a;给定测试样本&#xff0c;基于某种距离度量找出训练集中与其最靠近的k个训练样本&#xff0c;然后基于这k个 “邻居”的信息来进行预测.通常&#xff0c;在…

常用排序算法时间复杂度和稳定性

以下是常用排序算法时间复杂度和稳定性&#xff0c;也是常考的&#xff1a;

如何衡量企业品牌力?判断指标有哪些?

企业品牌力是指品牌在市场中的竞争力和影响力&#xff0c;它反映了品牌的价值、知名度、忠诚度、感知质量、差异化以及市场表现等方面。要去衡量一个企业的品牌力&#xff0c;大多从品牌的知名度、忠诚度、所占市场份额、顾客口碑、社媒影响力、品牌资产价值等多方面去判断。我…

【计网】从零开始使用TCP进行socket编程 --- 客户端与服务端的通信实现

阵雨后放晴的天空中&#xff0c; 出现的彩虹很快便会消失。 而人心中的彩虹却永不会消失。 --- 太宰治 《斜阳》--- 从零开始使用TCP进行socket编程 1 TCP与UDP2 TCP服务器类2.1 TCP基础知识2.2 整体框架设计2.3 初始化接口2.4 循环接收接口与服务接口 3 服务端与客户端测试…

Jboss CVE-2015-7501 靶场攻略

漏洞介绍 这是经典的JBoss反序列化漏洞&#xff0c;JBoss在/invoker/JMXInvokerServlet请求中读取了⽤户传⼊的对象&#xff0c;然后我们利⽤Apache Commons Collections中的 Gadget 执⾏任意代码 影响范围 JBoss Enterprise Application Platform 6.4.4,5.2.0,4.3.0_CP10 …