机器学习深度学习——softmax回归的简洁实现

news2024/12/26 22:11:08

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

继续使用Fashion-MNIST数据集,并保持批量大小为256:

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

softmax回归的简洁实现

  • 初始化模型参数
  • 重新审视softmax的实现
    • 数学推导
    • 交叉熵函数
  • 优化算法
  • 训练

初始化模型参数

softmax的输出层是一个全连接层,因此,为了实现模型,我们只需要在Sequential中添加一个带有10个输出的全连接层。当然这里的Sequential并不是必要的,但是他是深度模型的基础。我们仍旧以均值为0,标准差为0.01来随机初始化权重。

# pytorch不会隐式地调整输入的形状
# 因此在线性层前就定义了展平层flatten,来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)  # 给net每一层跑一次init_weights函数

重新审视softmax的实现

数学推导

在之前的例子里,我们计算了模型的输出,然后将此输出送入交叉熵损失。看似合理,但是指数级计算可能会造成数值的稳定性问题。
回想一下之前的softmax函数:
y ^ j = e x p ( o j ) ∑ k e x p ( o k ) 其中 y ^ j 是预测的概率分布, o j 是未规范化的第 j 个元素 \hat{y}_j=\frac{exp(o_j)}{\sum_kexp(o_k)}\\ 其中\hat{y}_j是预测的概率分布,o_j是未规范化的第j个元素 y^j=kexp(ok)exp(oj)其中y^j是预测的概率分布,oj是未规范化的第j个元素
由于o中的一些数值会非常大,所以可能会让其指数值上溢,使得分子或分母变成inf,最后得到的预测值可能变成的0、inf或者nan。此时我们无法得到一个明确的交叉熵值。
提出解决这个问题的一个技巧:在继续softmax计算之前,先从所有的o中减去max(o),修改softmax函数的构造且不改变其返回值:
y ^ j = e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) ∑ k e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) \hat{y}_j=\frac{exp(o_j-max(o_k))exp(max(o_k))}{\sum_kexp(o_j-max(o_k))exp(max(o_k))} y^j=kexp(ojmax(ok))exp(max(ok))exp(ojmax(ok))exp(max(ok))
这样操作以后,可能会使得一些分子的exp(o-max(o))有接近0的值,即为下溢。这些值可能会四舍五入为0,这样就会使得预测值为0,那么此时要是取对数以后就会变为-inf。要是这样反向传播几步,我们可能会发现自己屏幕有一堆的nan。
尽管我们需要计算指数函数,但是我们最终会在计算交叉熵损失的时候会取他们的对数。尽管通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的式子:
l o g ( y ^ j ) = l o g ( e x p ( o j − m a x ( o k ) ) ∑ k e x p ( o k − m a x ( o k ) ) ) = l o g ( e x p ( o j − m a x ( o k ) ) ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) = o j − m a x ( o k ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) log(\hat{y}_j)=log(\frac{exp(o_j-max(o_k))}{\sum_kexp(o_k-max(o_k))})\\ =log(exp(o_j-max(o_k)))-log(\sum_kexp(o_k-max(o_k)))\\ =o_j-max(o_k)-log(\sum_kexp(o_k-max(o_k))) log(y^j)=log(kexp(okmax(ok))exp(ojmax(ok)))=log(exp(ojmax(ok)))log(kexp(okmax(ok)))=ojmax(ok)log(kexp(okmax(ok)))
通过上式,我们避免了计算单独的exp(o-max(o)),而是直接使用o-max(o)。
因此,我们计算交叉熵函数的时候,传递的不是未规范化的预测o,而不是softmax。
但是我们也希望保留传统的softmax函数,以备我们要评估通过模型输出的概率。

交叉熵函数

在这里介绍一下交叉熵函数,以用于上面推导所需的需求:

torch.nn.CrossEntropyLoss(weight=None,
						ignore_index=-100,
						reduction='mean')

交叉熵函数是将LogSoftMax和NLLLoss集成到一个类中,通常用于多分类问题。其参数使用情况:

ignore_index:指定被忽略且对输入梯度没有贡献的目标值。
reduction:string类型的可选项,可在[none,mean,sum]中选。none表示不降维,返回和target一样的形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
weight:是一个一维的张量,包含n个元素,分别代表n类的权重,在训练样本不均衡时很有用,默认为None:
(1)当weight=None时,损失函数计算方式为
loss(x,class)=-log(exp(x[class])/Σexp(x[j]))=-x[class]+log(Σexp(x[j])
(2)当weight被指定时,损失函数计算方式为:
loss(x,class)=weight[class]×(-x[class]+log(Σexp(x[j]))

# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其导数
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

调用之前定义的训练函数来训练模型:

# 调用之前的训练函数来训练模型
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/796216.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java并发编程第2讲——线程基础

目录 一、线程简介 1.1 什么是线程 1.2 线程的组成 1.3 线程的特点 1.4 Java的main方法 二、线程的创建与启动 2.1 线程的创建 2.1.1 继承Thread类(无返回值) 2.1.2 实现Runnable接口(无返回值) 2.1.3 实现Callable接口…

【Luogu】 P4331 [BalticOI 2004] Sequence 数字序列

题目链接 点击打开链接 题目解法 首先做一个重要的转化:把 b i b_i bi​ 单调上升变为 b i b_i bi​ 单调不降 如何转化?将 a i − i a_i-i ai​−i 变成新的 a i a_i ai​,将 b i − i b_i-i bi​−i 变新的 b i b_i bi​&#xff…

练习时长两年半的双机热备

1.双机热备技术产生的背景 传统的组网方式如下左图所示,内部用户和外部用户的交互报文全部通过防火墙A。如果防火墙A出现故障,内部网络中所有以防火墙A作为默认网关的主机与外部网络之间的通讯将中断,通讯可靠性无法保证。防火墙作为安全设备…

金蝶云星空与旺店通·企业版对接集成物料查询连通创建货品档案(cp_KW货品同步)

金蝶云星空与旺店通企业版对接集成物料查询连通创建货品档案(cp_KW货品同步) 接入系统:金蝶云星空 金蝶K/3Cloud结合当今先进管理理论和数十万家国内客户最佳应用实践,面向事业部制、多地点、多工厂等运营协同与管控型企业及集团公司,提供一个…

前端实现导出excel表格(合并表头)

需求:勾选行导出为excel表格(合并表头 ) 一、安装插件 npm install --save file-saver xlsx运行项目报如下警告的话 运行npm install xlsx0.16.0 --save 来降低版本号(最初我安装的版本号是0.18.16的版本)再次运行项目…

VS构建项目报错信息及解决办法01

报错信息及解决1: 报错信息详情:1>MSVCRT.lib(exe_winmain.obj) : error LNK2019: 无法解析的外部符号 _WinMain16,该符号在函数 "int cdecl scrt_common_main_seh(void)" (?__scrt_common_main_sehYAHXZ) 中被引用 原因&…

SAP RFC介绍(sRFC/aRFC/tRFC/qRFC/pRFC)

异步RFC: aRFC后缀: STARTING NEW TASK CALL FUNCTION - STARTING NEW TASK / RECEIVE / WAIT UNTIL tRFC 后缀: IN BACKGROUND TASK. CALL FUNCTION - IN BACKGROUND TASK qRFC 是tRFC的一个扩展。它允许你将多个tRFC调用序列化为一个…

RocketMQ集群4.9.2升级4.9.6版本

本文主要记录生产环境短暂停机升级RocketMQ版本的过程 一、整体思路 1.将生产环境MQ4.9.2集群同步到测试环境,并启动,确保正常运行。 2.参照4.9.2配置4.9.6集群 3.停掉4.9.2集群,启动4.9.6集群,测试确保正常运行。 4.停掉4.9.6集…

Python Web开发技巧VII

目录 装饰器inject_serializer 装饰器atomic rebase git 清理add的数据 查看git的当前工作目录 makemigrations文件名称 action(detailTrue, methods["GET"]) 如何只取序列化器的一个字段进行返回 Response和JsonResponse有什么区别 序列化器填表和单字段如…

理解Android中不同的Context

作者:两日的blog Context是什么,有什么用 在Android开发中,Context是一个抽象类,它是Android应用程序环境的一部分。它提供了访问应用程序资源和执行各种操作的接口。可以说,Context是Android应用程序与系统环境进行交…

LoadRunner使用教程

1. LoadRunner简介 LoadRunner是一款广泛使用的性能测试工具 可以对各种应用程序进行性能测试,包括Web应用程序、移动应用程序、企业级应用程序等。它提供了一个综合的性能测试解决方案,包括测试计划设计、脚本录制、测试执行、结果分析和报告生成等功…

三、函数-5.流程函数

一、常见函数 【对比】 二、示例 1、if 和 ifnull -- if(value, t, f) 如果value为true,则返回t,否则返回f ok select if(true, ok, error);-- ifnull(value1, value2) 如果value1不为空,返回value1,否则返回value2&#…

MFC表格控件CListCtrl的改造及用法

1、目的 简单描述MFC的表格控件使用方法。Qt适用习惯了以后MFC用的比较别扭,因此记录一下以备后续复制代码使用。由于MFC原生的CListCtrl比较局限,比如无法改变表格的背景色、文字颜色等设定,因此先对CListCtrl类进行重写,以便满足…

哪些报表工具更适合中国企业?看完本文就知道了

企业级报表工具是指能够处理大量数据、支持多种数据源连接、具有强大的数据分析和可视化功能的工具。进入大数据时代,企业数据量剧增、分析需求精细化且要求高效率、高灵活自主性,一般都采用BI报表工具来做智能化、可视化数据分析,推动企业的…

Neo4j数据库中导入CSV示例数据

本文简要介绍Neo4j数据库以及如何从CSV文件中导入示例数据,方便我们快速学习测试图数据库。首先介绍简单数据模型以及基本图查询概念,然后通过LOAD CSV命令导入数据,生成节点和关系。 环境准备 读者可以快速安装Neo4j Desktop,启…

Mysql中(@i:=@i+1)的介绍

i:i1 表达式 生成伪列实现自增序列 语法: select (i:i1) as ,t.* from table_name t,(select i:0) as j (i:i1)代表定义一个变量,每次叠加 1; (select i:0) as j 代表建立一个临时表,j是随便取的表名,但别名一定…

python和c++哪个更值得学,python和c++学哪个简单

大家好,本文将围绕python和c哪个更值得学展开说明,python和c学哪个简单是一个很多人都想弄明白的事情,想搞清楚c和python哪个好学需要先了解以下几个事情。 1、想学编程,选择Python 还是Java或者C? 首先,我…

MySQL索引失效原因及解决方案

MySQL索引失效原因及解决方案 在使用MySQL数据库时,索引是一种重要的性能优化工具。然而,有时候我们可能会遇到索引失效的情况。本文将介绍几种常见的MySQL索引失效原因以及相应的解决方案,并提供SQL语句的错误示例和正确示例。 1. 字符串字…

HarmonyOS学习路之方舟开发框架—学习ArkTS语言(状态管理 二)

Prop装饰器:父子单向同步 Prop装饰的变量可以和父组件建立单向的同步关系。Prop装饰的变量是可变的,但是变化不会同步回其父组件。 概述 Prop装饰的变量和父组件建立单向的同步关系: Prop变量允许在本地修改,但修改后的变化不会…

tinkerCAD案例:11.制作齿轮

tinkerCAD案例:11.制作齿轮 制作齿轮 Add a cylinder to be the main part of the gear. 添加一个圆柱体作为齿轮的主要部分。 说明 Click and drag a cylinder onto the Workplane. 单击圆柱体并将其拖动到工作平面上。 Change the cylinder dimensions to 35mm …