[pytorch] --- pytorch基础之损失函数与反向传播

news2024/9/21 4:28:19

1 损失函数

1.1 Loss Function的作用

  • 每次训练神经网络的时候都会有一个目标,也会有一个输出。目标和输出之间的误差,就是用Loss Function来衡量的。所以Loss误差是越小越好的。
  • 此外,我们可以根据误差Loss,指导输出output接近目标target。即我们可以以Loss为依据,不断训练神经网络,优化神经网络中各个模块,从而优化output 。

Loss Function的作用:
(1)计算实际输出和目标之间的差距
(2)为我们更新输出提供一定的依据,这个提供依据的过程也叫反向传播。

我们可以看下pytorch为我们提供的损失函数:https://pytorch.org/docs/stable/nn.html#loss-functions

1.2 损失函数简单示例

以L1Loss损失函数为例子,他其实很简单,就是把实际值与目标值,挨个相减,再求个均值。就是结果。(这个结果就反映了实际值的好坏程度,这个结果越小,说明越靠近目标值)
在这里插入图片描述
示例代码

import torch
from torch.nn import L1Loss

inputs = torch.tensor([1,2,3],dtype=torch.float32) # 实际值
targets = torch.tensor([1,2,5],dtype=torch.float32) # 目标值
loss = L1Loss()
result = loss(inputs,targets)
print(result)

输出结果:tensor(0.6667)
接下来我们看下两个常用的损失函数:均方差和交叉熵误差

1.3 均方差

均方差:实际值与目标值对应做差,再平方,再求和,再求均值。
那么套用刚才的例子就是:(0+0+2^2)/3=4/3=1.33333…

代码实现

import torch
from torch.nn import L1Loss, MSELoss

inputs = torch.tensor([1,2,3],dtype=torch.float32) # 实际值
targets = torch.tensor([1,2,5],dtype=torch.float32) # 目标值
loss_mse = MSELoss()

result = loss_mse(inputs,targets)
print(result)

输出结果:tensor(1.3333)

1.4 交叉熵误差:

这个比较复杂一点,首先我们看官方文档给出的公式
先放一个别人的解释:https://www.jianshu.com/p/6049dbc1b73f
这里先用代码实现一下他的简单用法:

import torch
from torch.nn import L1Loss, MSELoss, CrossEntropyLoss

x = torch.tensor([0.1,0.2,0.3]) # 预测出三个类别的概率值
y = torch.tensor([1]) # 目标值  应该是这三类中的第二类 也就是下标为1(从0开始的)
x = torch.reshape(x,(1,3)) # 修改格式  交叉熵函数的要求格式是 (N,C) N是bitch_size C是类别
# print(x.shape)
loss_cross = CrossEntropyLoss()
result = loss_cross(x,y)
print(result)

输出结果:tensor(1.1019)

1.5 如何在神经网络中用到Loss Function

# -*- coding: utf-8 -*-
# 作者:小土堆
# 公众号:土堆碎念
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    result_loss = loss(outputs, targets)
    print(result_loss)

2 反向传播

所谓的反向传播,就是利用我们得到的loss值,来对我们神经网络中的一些参数做调整,以达到loss值降低的目的。(图片经过一层一层网络的处理,最终得到结果,这是正向传播。最终结果与期望值运算得到loss,用loss反过来调整参数,叫做反向传播。个人理解,不一定严谨!)

2.1 backward

这里利用loss来调整参数,主要使用的方法是梯度下降法。
这个方法原理其实还是有点复杂的,但是pytorch为我们实现好了,所以用起来很简单。
调用损失函数得到的值的backward函数即可。

loss = CrossEntropyLoss() # 定义loss函数
# 实例化这个网络
test = Network()
for data in dataloader:
    imgs,targets = data
    outputs = test(imgs) # 输入图片
    result_loss = loss(outputs,targets)
    result_loss.backward() # 反向传播
    print('ok')

打断点调试,可以看到,grad属性被赋予了一些值。如果不用反向传播,是没有值的
当然,计算出这个grad值只是梯度下降法的第一步,算出了梯度,如何下降呢,要靠优化器

2.2 optimizer

优化器也有好几种,官网对优化器的介绍:https://pytorch.org/docs/stable/optim.html
不同的优化器需要设置的参数不同,但是有两个是大部分都有的:模型参数与学习速率
我们以SDG优化器为例,看下用法:

# 实例化这个网络
test = Network()
loss = CrossEntropyLoss() # 定义loss函数
# 构造优化器
# 这里我们选择的优化器是SGD 传入两个参数 第一个是个模型test的参数 第二个是学习率
optim = torch.optim.SGD(test.parameters(),lr=0.01)

for data in dataloader:
    imgs,targets = data
    outputs = test(imgs) # 输入图片
    result_loss = loss(outputs,targets) # 计算loss
    optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
    result_loss.backward() # 反向传播 求得梯度
    optim.step() # 对参数进行调优

这里面我们刚学得主要是这三行:
清零,反向传播求梯度,调优

optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
result_loss.backward() # 反向传播 求得梯度
optim.step() # 对参数进行调优

我们可以打印一下loss,看下调优后得loss有什么变化。
注意:我们dataloader是把数据拿出来一遍,那么看了一遍之后,经过这一遍的调整,下一遍再看的时候,loss才有变化。
所以,我们先让让他学习20轮,然后看一下每一轮的loss是多少

# 实例化这个网络
test = Network()
loss = CrossEntropyLoss() # 定义loss函数
# 构造优化器
# 这里我们选择的优化器是SGD 传入两个参数 第一个是个模型test的参数 第二个是学习率
optim = torch.optim.SGD(test.parameters(),lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs,targets = data
        outputs = test(imgs) # 输入图片
        result_loss = loss(outputs,targets) # 计算loss
        optim.zero_grad() #因为这是在循环里面 所以每次开始优化之前要把梯度置为0 防止上一次的结果影响这一次
        result_loss.backward() # 反向传播 求得梯度
        optim.step() # 对参数进行调优
        running_loss = running_loss + result_loss # 记录下这一轮中每个loss的值之和
    print(running_loss) # 打印每一轮的loss值之和

可以看到,loss之和一次比一次降低了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2097784.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浏览器百科:网页存储篇-Cookie详解(一)

1.引言 在现代网页开发中,数据存储和管理是提升用户体验的重要环节之一。作为网页存储技术的元老,Cookie 自从诞生以来就扮演着不可或缺的角色。Cookie 允许网站在用户浏览器中存储小块数据,从而实现状态保持、用户跟踪以及个性化设置等功能…

数仓基础(六):离线与实时数仓区别和建设思路

文章目录 离线与实时数仓区别和建设思路 一、离线数仓与实时数仓区别 二、实时数仓建设思路 离线与实时数仓区别和建设思路 ​​​​​​​一、离线数仓与实时数仓区别 离线数据与实时数仓区别如下: 对比方面 离线数仓 实时数仓 架构选择 传统大数据架构 …

KRaft模式下的Kafka启动指南:摆脱Zookeeper依赖

一、背景介绍 多年来,人们一直在同时使用Apache ZooKeeper和Apache Kafka。但是自Apache Kafka 3.3发布以来,它就可以在没有ZooKeeper的情况下运行。同时它包含了新的命令kafka-metadata-quorum和kafka-metadata-shell?该如何安装新版kafka&#xff0c…

快手小店多店铺管理神器:甜羊浏览器

随着短视频平台的兴起,快手小店已经成为众多商家的重要销售渠道。然而,对于同时管理多个快手小店的商家来说,如何高效地运营这些店铺成为了一大挑战。特别是在需要同时登录和管理多个店铺账号时,问题尤为突出。那么,如…

【Python报错已解决】“ImportError: cannot import name ‘triu‘ from ‘scipy.linalg‘“?

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 引言:一、问题描述1.1 报错示例:以下代码尝试从 scipy.linalg 中导入 triu 函数。1.2 报错分析…

@JsonFormat失败问题处理

JsonFormat失败问题处理 在开发中经常使用到时间格式,如果数据库的时间是timestamp格式的,则返回的格式通过带有毫秒 例如2024-08-30 14:53:58.236 这样子的格式,通常不是我们想要的; 但是我们又不想再后端写更多的代码&#xff…

公司电脑的敏感文件怎么审查?七大敏感文件管控策略,高效应对企业泄密风险!

在数字化时代,企业的敏感文件如同珍贵的宝藏,需时刻警惕潜在的风险。 古有"城门失火,殃及池鱼"之警,今有企业敏感信息泄露,牵一发而动全身之虞。 因此,如何有效审查与管理公司电脑中的敏感文件…

将.xml格式转换为YOLO所需的.txt文件格式

首先,原始的.xml数据集基础构成如下: image目录结构如下: label目录结构如下: .xml内容如下: 之后修改代码如下: import xml.etree.ElementTree as ET import os, cv2 import numpy as np from os import…

机器学习(西瓜书)第 3 章 线性模型

3.1 基本形式 例如若在西瓜 问题中学得“/好瓜⑺- 0.2 • n色泽 0.5 •/根蒂 0.3 •力敲声 1”,则意味着可 通过综合考虑色泽、根蒂和敲声来判断瓜好不好,其中根蒂最要紧,而敲声比 色泽更重要. 本章介绍几种经典的线性模型.我们先从回归任务…

为什么正午选她演大女主戏?看到殷桃这个片段,我全懂了

最近小编听说正午的最新力作《凡人歌》要上了,而且女主还是我特别喜欢的殷桃,赶紧马不停蹄的去追剧,结果狠狠爱上了殷桃的演技! 剧里殷桃饰演的沈琳是一位家庭主妇,她以为她放弃了工作,做家庭主妇&#xff…

你还在为编程效率低下而烦恼吗?编程界的神级辅助!一键解锁高效编程模式,让你的工作效率飙升不止一倍!

哪个编程工具让你的工作效率翻倍? 第一章 引言 在软件开发领域,编程工具的重要性不言而喻。它们不仅能够加速开发过程,还能提高代码质量,从而显著提升开发人员的工作效率。随着技术的不断进步,越来越多的编程工具涌现…

多头切片的关键:Model 类 call解释;LlamaModel 类 call解释;多头切片的关键:cache的数据拼接

目录 Model 类 call解释 LlamaModel 类 call解释 方法签名 方法体 总结 Model 类 call解释 这段代码定义了一个特殊的方法 __call__,它是Python中的一个魔术方法(magic method),允许类的实例像函数那样被调用。在这个上下文中,这个方法很可能被定义在一个封装了某种…

java宠物商城网站系统的设计与实现

springboot508基于Springboot宠物商城网站系统 题目:宠物商城网站系统的设计与实现 摘 要 如今社会上各行各业,都喜欢用自己行业的专属软件工作,互联网发展到这个时候,人们已经发现离不开了互联网。新技术的产生,往往…

算法图解(1)

配套代码: https://github.com/egonSchiele/grokking_algorithms?tabreadme-ov-filehttps://github.com/egonSchiele/grokking_algorithms?tabreadme-ov-file 理论 数据结构:组织和存储数据的方式,影响程序的性能和存储效率 算法&#…

OpenHarmony使用ArkUI Inspector分析布局

● 摘要:视图的嵌套层次会影响应用的性能,开发者应该移除多余的嵌套层次,缩短组件刷新耗时。本文会介绍如何使用ArkUI Inspector工具分析布局,提示应用响应性能。 ● 关键字:列举本文相关的关键字:OpenHar…

Linux防火墙问题排查记录

问题描述 在业务当中,开通了防火墙,导致外部数据无法通过SFTP服务访问本机的服务,根据防火墙策略判断,应该是有一些IP没有被加进accept策略导致的,所以需要查看防火墙日志来追溯哪些IP被过滤掉了,只要放通…

开学了 需要考研的同学快看过来!考研倒计时你的鞭策神器!

开学了 需要考研的同学快看过来!考研倒计时你的鞭策神器! 2025年考研准备又要开始了,考试科目包括思想政治理论、管理类联考综合能力、外国语、业务课一、业务课二等。考研初试是每位考生迈向研究生阶段的重要一步,希望考生在考试…

python+requests 搭建接口自动化测试框架【超详细】

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、前言 Python是一种简单易学、功能强大的编程语言,广泛应用于各种软件开发和测试场景中。requests是Python中流行的HTTP库,支持发送…

windows下安装elasticSearch和kibana

下载es 下载地址官网 下载后是个压缩包(elasticsearch-8.15.0-windows-x86_64),解压即可 启动 配置 改一下 /conf/jvm.options文件,最后加一行编码配置,这个是为了启动后防止控制台乱码 -Dfile.encodingGBK启动es 依赖jdk8环境&#xf…

Qt中的父窗口子窗口和父类子类的区别

好多人在开发初期,往往将父子窗口和父子类搅在一起容易搞混。 今天借着这篇文章给大家分辨一下。 C中我们常说子类继承自父类,子类具有父类所有的特性和功能。所以父类和子类是继承关系。 而子窗体和父窗体,不是继承关系,准确地讲…