深度学习中,用损失的均值或者总和反向传播的区别

news2024/12/24 17:37:34

如深度学习中代码:

def train_epoch_ch3(net, train_iter, loss, updater):
    """The training loop defined in Chapter 3."""
    # Set the model to training mode
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, sum of training accuracy, no. of examples
    metric = Accumulator(3)
    for X, y in train_iter:
        # Compute gradients and update parameters
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # Using PyTorch in-built optimizer & loss criterion
            updater.zero_grad()
            l.backward()
            updater.step()
            metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
        else:
            # Using custom built optimizer & loss criterion
            l.sum().backward()
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # Return training loss and training accuracy
    return metric[0] / metric[2], metric[1] / metric[2]

对于循环for X, y in train_iter:每一次迭代都是一次小批量训练 

当用torch里的优化器来更新参数时,第一个if语句执行:

1.此时由于torch默认累计梯度,所以每次循环都得梯度置零

2.然后torch进行反向传播(如果是使用了torch定义的loss,那么得到的l是标量,即平均损失,次此时直接反向传播即可)。

loss = nn.CrossEntropyLoss(reduction='mean'),reduction='mean'意味着损失会返回均值,reduction理解成降维

loss = nn.CrossEntropyLoss(reduction='None'),reduction='None'意味着损失会返回原来的形式,即矢量

为什么不用总和形式呢?均值会好一点,因为每次训练有不同的batch_size(如果我把样本分成三个批量,一次训练中,首先用批量一训练,然后批量二,批量三,然后批量都训练完后,再进行新的一轮训练,直到收敛),用均值的话会默认使用l.sum()/batch_size,batch_size由内部求出,这个不用自己写,就实现了batch_size的解耦,不容易出错,更新参数时形式就变成了w_i=w_i-\alpha *loss,因为此时的损失是均值了

3.进行参数更新

4.统计损失与精度,为什么要用float(l)*len(y)呢

 因为l=loss(y_hat,y),这里的损失函数是从外部传进来的,所以如果传进来的是torch自己定义的损失函数,应该计算的是平均损失,那么要变成整体损失就得float(l)*len(y)。

如果loss是自己定义的,可能只是计算了损失,没求平均(看上面的函数,此时的loss应该是矢量,所以才会有l.sum())

当自己实现优化器来更新参数时,else语句执行:

1.此时自己可能没实现累计梯度,所以不用梯度清零

2.反向传播,但由于此时自己实现了loss函数,可能loss是矢量,要转成标量来反向传播,一个矢量转标量的好方法就是求和,即l.sum().backward(),

3.更新参数,注意到updater传入了一个参数X.shape[0],即样本数量batch_size,外边要注意更新参数时为w_i =w_i-\alpha /batchSize*loss,因为此时的l是总和

4.统计损失与精度

sum和mean其实主要影响了“梯度的大小”,反向传播时,依据损失求梯度,如果是sum,则梯度会比mean大n倍,那么在学习率不变的情况下,步子会迈得很长,体现到图形上就是正确率提升不了。所以需要缩小学习率。

l.mean().backward()和l.sum().backward()的区别在于它们计算梯度的方式。l.mean().backward()计算的是平均损失对权重的梯度,而l.sum().backward()计算的是总损失对权重的梯度。

在实践中,这两种方式通常会得到相似的结果,因为它们都是在尝试最小化损失。然而,使用l.mean().backward()可能会使得梯度的大小更稳定,因为它不会因为批量大小的变化而变化。这可能会使得训练过程更稳定,特别是在批量大小可能变化的情况下。

另一方面,使用l.sum().backward()可能会使得梯度的大小更大,这可能会导致训练过程更快,但也可能导致训练过程更不稳定。

总的来说,哪种方式更好取决于你的具体情况。如果你的批量大小是固定的,那么你可能会发现l.sum().backward()和l.mean().backward()在实践中没有太大的区别。如果你的批量大小可能变化,那么你可能会发现l.mean().backward()在实践中更稳定

补:总训练函数:

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train a model (defined in Chapter 3)."""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2264846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UML图【重要】

文章目录 2.1 类图概述2.2 类图的作用2.3 类图表示法2.3.1 类的表示方式2.3.2 类与类之间关系的表示方式2.3.2.1 关联关系2.3.2.2 聚合关系2.3.2.3 组合关系2.3.2.4 依赖关系2.3.2.5 继承关系2.3.2.6 实现关系 统一建模语言&#xff08;Unified Modeling Language&#xff0c;U…

Flask中@app.route()的methods参数详解

诸神缄默不语-个人CSDN博文目录 在 Flask 中&#xff0c;app.route 是用于定义路由的核心装饰器&#xff0c;开发者可以通过它为应用指定 URL 映射及相应的处理函数。在处理 HTTP 请求时&#xff0c;不同的业务场景需要支持不同的 HTTP 方法&#xff0c;而 app.route 的 metho…

JavaSE---String(含一些源码)

&#xff08;一&#xff09;字符串构造 我们如何创建一个String类型的对象&#xff1f;有三种&#xff1a; String s1new String("hello"); //直接new一个String对象String s2"hello"; //使用常量串构造final char[] chars {h,e,l,l,o}; Strin…

0.96寸OLED显示屏详解

我们之前讲了 LCD1602&#xff0c;今天我们将它的进阶模块——OLED。它接线更少&#xff0c;性能更强&#xff0c;也能显示中文和图像了。 大家在学习单片机的时候是否会遇到调试的问题呢&#xff1f;例如 “这串代码我到底运行成功了没有” &#xff0c;我相信很多刚开始学习…

用un-app写的动漫风格的登录界面

动漫风格的的登录、注册界面模板&#xff0c;使用uni-app编写&#xff0c;直接复制粘贴即可。 废话不多说&#xff0c;代码如下&#xff1a; login.vue文件 <template><view class"content"><view class"tab-box"><text class"c…

Pytorch | 从零构建ParNet/Non-Deep Networks对CIFAR10进行分类

Pytorch | 从零构建ParNet/Non-Deep Networks对CIFAR10进行分类 CIFAR10数据集ParNet架构特点优势应用 ParNet结构代码详解结构代码代码详解SSEParNetBlock 类DownsamplingBlock 类FusionBlock 类ParNet 类 训练过程和测试结果代码汇总parnet.pytrain.pytest.py 前面文章我们构…

【服务器】linux服务器管理员查看用户使用内存情况

【服务器】linux服务器管理员查看用户使用硬盘内存情况 1、查看所有硬盘内存使用情况 df -h2、查看硬盘挂载目录下所有用户内存使用情况 du -sh /public/*3、查看某个用户所有文件夹占用硬盘内存情况 du -sh /public/zhangsan/*

[搜广推]王树森推荐系统——其他召回通道

地理位置召回 GeoHash召回 想法&#xff1a;用户可能对附近发生的事感兴趣 方法&#xff1a;对经纬度的编码&#xff0c;地图上一个长方形区域 索引&#xff1a;GeoHash -> 优质笔记列表(按时间倒排) 这条召回通道没有个性化 同城召回 想法&#xff1a;用户可能对同…

重温设计模式--外观模式

文章目录 外观模式&#xff08;Facade Pattern&#xff09;概述定义 外观模式UML图作用 外观模式的结构C 代码示例1C代码示例2总结 外观模式&#xff08;Facade Pattern&#xff09;概述 定义 外观模式是一种结构型设计模式&#xff0c;它为子系统中的一组接口提供了一个统一…

OpenCV学习——图像融合

import cv2 as cv import cv2 as cvbg cv.imread("test_images/background.jpg", cv.IMREAD_COLOR) fg cv.imread("test_images/forground.png", cv.IMREAD_COLOR)# 打印图片尺寸 print(bg.shape) print(fg.shape)resize_size (1200, 800)bg cv.resize…

ECharts热力图-笛卡尔坐标系上的热力图,附视频讲解与代码下载

引言&#xff1a; 热力图&#xff08;Heatmap&#xff09;是一种数据可视化技术&#xff0c;它通过颜色的深浅变化来表示数据在不同区域的分布密集程度。在二维平面上&#xff0c;热力图将数据值映射为颜色&#xff0c;通常颜色越深表示数据值越大&#xff0c;颜色越浅表示数…

进程间关系与守护进程

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 进程间关系与守护进程 收录于专栏[Linux学习] 本专栏旨在分享学习Linux的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 1. 进程组 什…

LightGBM分类算法在医疗数据挖掘中的深度探索与应用创新(上)

一、引言 1.1 医疗数据挖掘的重要性与挑战 在当今数字化医疗时代,医疗数据呈爆炸式增长,这些数据蕴含着丰富的信息,对医疗决策具有极为重要的意义。通过对医疗数据的深入挖掘,可以发现潜在的疾病模式、治疗效果关联以及患者的健康风险因素,从而为精准医疗、个性化治疗方…

【文档搜索引擎】缓冲区优化和索引模块小结

开机之后&#xff0c;首次制作索引会非常慢&#xff0c;但后面就会快了 重启机器&#xff0c;第一次制作又会非常慢 这是为什么呢&#xff1f; 在 parserContent 里面&#xff0c;我们进行了一个读文件的操作 计算机读取文件&#xff0c;是一个开销比较大的操作&#xff0c; …

html+css网页设计 旅游 移动端 雪花旅行社4个页面

htmlcss网页设计 旅游 移动端 雪花旅行社4个页面 网页作品代码简单&#xff0c;可使用任意HTML辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码 …

3 JDK 常见的包和BIO,NIO,AIO

JDK常见的包 java.lang:系统基础类 java.io:文件操作相关类&#xff0c;比如文件操作 java.nio:为了完善io包中的功能&#xff0c;提高io性能而写的一个新包 java.net:网络相关的包 java.util:java辅助类&#xff0c;特别是集合类 java.sql:数据库操作类 IO流 按照流的流向分…

从零创建一个 Django 项目

1. 准备环境 在开始之前&#xff0c;确保你的开发环境满足以下要求&#xff1a; 安装了 Python (推荐 3.8 或更高版本)。安装 pip 包管理工具。如果要使用 MySQL 或 PostgreSQL&#xff0c;确保对应的数据库已安装。 创建虚拟环境 在项目目录中创建并激活虚拟环境&#xff…

ubuntu20.04安装imwheel实现鼠标滚轮调速

ubuntu20.04安装imwheel实现鼠标滚轮调速 Ubuntu 系统自带的设置中仅具备调节鼠标速度的功能&#xff0c;而无调节鼠标滚轮速度的功能。其默认的鼠标滚轮速度较为缓慢&#xff0c;在查看文档时影响尚可接受&#xff0c;但在快速浏览网页时&#xff0c;滚轮速度过慢会给用户带来…

GitLab的安装与卸载

目录 GitLab安装 GitLab使用 使用前可选操作 修改web端口 修改Prometheus端口 使用方法 GitLab的卸载 环境说明 系统版本 CentOS 7.2 x86_64 软件版本 gitlab-ce-10.8.4 GitLab安装 Gitlab的rpm包集成了它需要的软件&#xff0c;简化了安装步骤&#xff0c;所以直接…

简单工厂模式和策略模式的异同

文章目录 简单工厂模式和策略模式的异同相同点&#xff1a;不同点&#xff1a;目的&#xff1a;结构&#xff1a; C 代码示例简单工厂模式示例&#xff08;以创建图形对象为例&#xff09;策略模式示例&#xff08;以计算价格折扣策略为例&#xff09;UML区别 简单工厂模式和策…