交叉熵函数与kl散度的区别

news2024/9/26 15:14:55

公式上的区别

手动计算的方式展示如何实现这两个损失函数

交叉熵损失函数

import torch
import torch.nn.functional as F

# 模型的输出 logits 和真实标签 target
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target = torch.tensor([0, 1])  # 真实标签

# 计算 softmax 以获得预测概率
pred_probs = F.softmax(logits, dim=1)

# 将 target 转换为 one-hot 编码
target_one_hot = F.one_hot(target, num_classes=logits.size(1))

# 交叉熵损失公式:L = - Σ y * log(ŷ)
cross_entropy_loss = - torch.sum(target_one_hot * torch.log(pred_probs)) / logits.size(0)

print('手动实现的交叉熵损失:', cross_entropy_loss)

kl散度

import torch
import torch.nn.functional as F

# 模型的输出 logits 和目标分布 target_probs
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.7, 0.2]])  # 目标分布

# 将 logits 转换为 log softmax
logits_log_softmax = F.log_softmax(logits, dim=1)

# KL 散度公式:D_KL(P || Q) = Σ P * (log P - log Q)
kl_div_loss = torch.sum(target_probs * (torch.log(target_probs) - logits_log_softmax)) / logits.size(0)

print('手动实现的KL散度:', kl_div_loss)

官方打包好的函数

交叉熵损失 (Cross Entropy Loss) 官方实现

import torch
import torch.nn as nn

# 创建交叉熵损失函数实例
cross_entropy_loss_fn = nn.CrossEntropyLoss()

# 假设模型的输出 logits 和真实标签 targets
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
targets = torch.tensor([0, 1])  # 真实标签 (整数形式)

# 计算交叉熵损失
loss = cross_entropy_loss_fn(logits, targets)
print('官方交叉熵损失:', loss)

KL 散度损失 (KL Divergence Loss) 官方实现

import torch
import torch.nn as nn
import torch.nn.functional as F

# 创建KL散度损失函数实例
kl_div_loss_fn = nn.KLDivLoss(reduction='batchmean')  # 使用 'batchmean' 计算每个样本的平均损失

# 假设模型的输出 logits 和目标分布 target_probs
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.7, 0.2]])  # 目标分布 (已经是 softmax 概率)

# 计算 log softmax 以用于 KL 散度
logits_log_softmax = F.log_softmax(logits, dim=1)

# 计算KL散度损失
kl_loss = kl_div_loss_fn(logits_log_softmax, target_probs)
print('官方KL散度损失:', kl_loss)

加上温度

是的,KL 散度在深度学习中,尤其是知识蒸馏(Knowledge Distillation)中,常常与温度参数(Temperature, TTT)结合起来使用。温度调节可以让模型的预测分布更加平滑,从而在蒸馏过程中更有效地传递知识。下面将解释为什么 KL 散度与温度结合,以及如何使用温度参数。

温度在知识蒸馏中的作用

在知识蒸馏中,通常有一个教师模型(Teacher Model)和一个学生模型(Student Model)。教师模型的输出概率分布用来指导学生模型的训练,但直接使用教师模型的概率分布往往过于“尖锐”(即,教师模型的 softmax 输出大部分概率集中在正确类别)。为了使分布更加平滑,加入了温度参数。

温度对 softmax 的影响

softmax 函数将模型的 logits 转换为概率分布。加上温度参数 TTT 后的 softmax 表达式为:

  • 当 T=1T = 1T=1,softmax 正常工作,输出标准的概率分布。
  • 当 T>1T > 1T>1,softmax 输出变得更加平滑(分布更加均匀)。
  • 当 T<1T < 1T<1,softmax 输出更加“尖锐”,即概率分布更加接近 one-hot 编码。

在知识蒸馏中,通过引入较高的温度 TTT,可以让教师模型输出的概率分布变得更加平滑,从而包含更多类的信息,帮助学生模型更好地学习。

温度结合 KL 散度

在知识蒸馏过程中,学生模型通常通过最小化学生模型与教师模型之间的KL 散度来学习教师模型的输出分布。引入温度后,KL 散度的损失计算公式如下:

其中:

  • T2是为了平衡梯度的影响,避免由于高温度导致的梯度缩小。
  • DKL​ 表示 KL 散度,用于比较教师模型和学生模型的概率分布。

PyTorch 实现带温度的 KL 散度

你可以在 PyTorch 中手动实现带温度参数的 KL 散度,如下所示:

import torch
import torch.nn.functional as F

def distillation_kl_divergence_loss(logits_student, logits_teacher, temperature):
    """
    计算带温度参数的KL散度,用于知识蒸馏
    :param logits_student: 学生模型的logits (未经过softmax)
    :param logits_teacher: 教师模型的logits (未经过softmax)
    :param temperature: 温度参数T
    :return: 知识蒸馏中的KL散度损失
    """
    # 计算log softmax(学生和教师模型的输出都经过温度缩放)
    log_probs_student = F.log_softmax(logits_student / temperature, dim=1)
    probs_teacher = F.softmax(logits_teacher / temperature, dim=1)

    # 计算KL散度损失,并乘以T^2
    kl_div_loss = F.kl_div(log_probs_student, probs_teacher, reduction='batchmean') * (temperature ** 2)
    
    return kl_div_loss

# 示例
logits_student = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
logits_teacher = torch.tensor([[2.1, 0.6, 0.1], [0.2, 2.6, 0.7]], requires_grad=False)
temperature = 2.0  # 温度参数

# 计算带温度的KL散度损失
kl_loss = distillation_kl_divergence_loss(logits_student, logits_teacher, temperature)
print('带温度的KL散度损失:', kl_loss)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2125889.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

测试开发基础——软件测试中的bug

二、软件测试中的Bug 1. 软件测试的生命周期 软件测试贯穿于软件的整个生命周期 需求分析 测试计划 测试设计与开发 测试执行 测试评估 上线 运行维护 用户角度&#xff1a;软件需求是否合理 技术角度&#xff1a;技术上是否可行&#xff0c;是否还有优化空间 测试角度…

某郊到家:互联网时代下的按摩服务革新

在快速发展的时代背景下&#xff0c;一群具有前瞻性的企业家在2018年勇敢地进军了按摩服务行业&#xff0c;引领了一场对传统模式的革新。他们不仅在竞争激烈的市场中站稳脚跟&#xff0c;还成功地在不断变化的市场环境中确立了自己的位置。 创新的商业模式和持续的努力&#x…

【2024】前端学习笔记2-有序列表-无序列表-描述列表

学习笔记 有序列表:ol基本使用:嵌套使用扩展:使用CSS改变序号类型无序列表:ul基本使用扩展:使用CSS改变符号类型扩展:使用CSS定制列表样式描述列表:dl基本使用扩展:使用CSS定制类型格式总结有序列表:ol 有序列表由<ol>标签包裹一组<li>(列表项)标签组成…

区块链之变:揭秘Web3对互联网的全面改变

随着技术的进步&#xff0c;区块链 逐渐从一个相对小众的概念演变为重塑互联网结构的核心力量。特别是 Web3 的兴起&#xff0c;标志着互联网进入了一个新的发展阶段。这一变革不仅仅是技术的升级&#xff0c;更是对互联网功能、数据控制和用户体验的全面重新定义。本文将详细探…

数学建模笔记—— 回归分析

数学建模笔记—— 回归分析 回归分析1. 回归分析的一般步骤2. 一元线性回归分析2.1 具体过程2.1.1 确定回归方程中的解释变量和被解释变量2.1.2 确定回归模型和建立回归方程2.1.3 利用回归直线进行估计和预测2.1.4 对回归方程进行各种检验(补充)1. 回归直线的拟合优度2. 显著性…

哪款提醒软件能清晰展示每日工作任务?

在快节奏的工作环境中&#xff0c;每天的工作任务堆积如山&#xff0c;如何有效地整理和清晰查看这些任务&#xff0c;成为了提高工作效率的关键。一款优秀的提醒软件能够帮助我们将任务条理化&#xff0c;确保每一项工作都能按时完成。 敬业签就是这样一款能够清晰展示每日工…

VR 尺寸美学主观评价-解决方案-现场体验研讨会报名

棣拓科技VR创新解决方案助力尺寸美学所见即所得! 诚邀各位行业专家莅临指导交流 请扫描海报二维码踊跃报名&#xff0c;谢谢 中国上海 2024.10.25 亮点介绍 1、通过精湛渲染技术&#xff0c;最真实展现设计效果&#xff0c;并通过VR设备一比一比例进行展现。 2、设置相关设…

suid提权的环境搭建+反弹shell

SUID&#xff08;Set User ID&#xff09;是一种特殊的文件权限设置&#xff0c;它允许文件在执行时具有文件所有者的权限。当具有SUID权限的文件被执行时&#xff0c;执行该文件的用户会暂时获得文件所有者的权限。这种权限通常用于需要高权限操作的程序&#xff0c;如‌passw…

建筑用能该如何统一管理?水电气集抄太麻烦?!看看这个吧!建筑能耗分析管理系统 您的运维“好帮手”

安科瑞刘鸿鹏 随着工业化和信息化进程的加速&#xff0c;企业对能源管理的需求愈加迫切。安科瑞电气股份有限公司推出的Acrel-5000能耗管理系统运用物联网技术&#xff0c;实时采集电表、水表、燃气表等能源计量仪表的数据&#xff0c;并结合大数据技术进行处理和存储。该平台旨…

uniapp(H5)设置反向代理,设置成功后页面报错

设置反向代理后&#xff0c;页面报错图&#xff1a; 反向代理代码&#xff1a;devServer下面就是配置对应的代理&#xff0c;一般这样就没问题了 "h5": {"router": {"mode": "hash"},"devServer": {"port": 517…

基于SpringBoot+Vue的古诗词学习软件系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的古诗词学…

9月编程排行榜来了!C语言跌至历史最低!

9月的编程语言排行榜终于出炉&#xff0c;令人意外的是&#xff0c;曾经风靡全球、无数开发者的首选——C语言&#xff0c;竟然跌到了历史最低点&#xff01;这一变化引发了整个编程社区的广泛关注和讨论。 大家周三好呀&#xff01;又来到金秋九月&#xff0c;又到了TIOBE编程…

【MATLAB源码-第264期】基于matlab的跳频通信系统仿真,采用MSK调制方式,差分解调;输出误码率曲线和各节点波形图。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 跳频通信系统是一种能够提高通信抗干扰能力的技术&#xff0c;它通过在传输过程中不断地改变载波频率来避开干扰或者窃听。在这套跳频通信系统中&#xff0c;我们采用了最小频移键控&#xff08;MSK&#xff09;作为调制方式…

Jenkins 详解,几分钟学会,自动编译/部署/发布软件

大家好&#xff0c;欢迎来到停止重构的频道。 本期我们详细讨论Jenkins。 随着互联网应用越来越多&#xff0c;系统规模也越来越大&#xff0c;DevOps、CI/CD等概念也被重视起来&#xff0c;持续交付/持续集成/自动化部署等理念也被越来越多的团队接受。 而本期介绍的Jenkin…

智能交通(三)——Elsevier特刊推荐

特刊征稿 01 期刊名称&#xff1a; Vehicular Communications 特刊名称&#xff1a; Computational Aspects of Vehicular Networks 截止时间&#xff1a; 论文提交日期:2024年7月21日 录用通知:2024年9月30日 期末论文:2024年10月30日 目标及范围&#xff1a; 主题包括…

windows10下本机FTP服务搭建教程

文章目录 前言一、FTP服务器简介二、开启FTP服务站点&#xff08;所有用户可访问&#xff09;1.安装FTP服务2.配置FTP服务器3.本机访问ftp服务 三、开启FTP服务站点&#xff08;指定用户可访问&#xff09;1.创建本地用户2.添加FTP站点3.本机访问ftp服务 总结 前言 ftp服务器主…

Linux——分离部署,分化压力

PQS/TPS 每秒请求数/ 每秒事务数 // 流量衡量参数 可以根据预估QPS 和 服务器的支持的最高QPS 对照计算 就可以得出 需要上架的服务器的最小数量 PV 页面浏览数 UV 独立用户访问量 // 对于网站的总体访问量 response time 响应时间 // 每个请求的响应时间…

828华为云征文 | Flexus X实例与Harbor私有镜像仓库的完美结合

前言 华为云828企业上云节&#xff0c;Flexus X实例携手Harbor私有镜像仓库&#xff0c;共创云上安全高效新生态&#xff01;Flexus X以其卓越性能与稳定性&#xff0c;为Harbor提供了理想的运行环境。Harbor作为领先的私有镜像仓库&#xff0c;与Flexus X完美结合&#xff0c;…

[OpenGL]使用OpenGL绘制三角形

一、简介 本文介绍了如何在linux/win(wsl2)环境下&#xff0c;使用GLFWGLAD实现绘制三角形。 本文内容基本根据LearnOpengGL-入门-你好&#xff0c;三角形整理完成&#xff0c;读者也可以参考LearnOpengGL-入门-你好&#xff0c;三角形自行学习如何使用OpenGL绘制三角形。 按…

【人工智能学习笔记】3_2 机器学习基础之机器学习经典算法介绍

线性回归算法的定义和任务类型 定义:线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法任务类型:回归应用场景:异常指标监控 农业贷款监控过拟合和欠拟合 定义:过拟合和欠拟合用来度量模型泛化能力的直观表现欠拟合:模型…