4_损失函数和优化器

news2025/1/21 6:01:14

 教学视频:损失函数与反向传播_哔哩哔哩_bilibili

损失函数(Loss Function)

损失函数是衡量模型预测输出与实际目标之间差距的函数。在监督学习任务中,我们通常希望模型的预测尽可能接近真实的目标值。损失函数就是用来量化模型预测的误差大小的一种方法。

作用:

  1. 衡量模型性能: 损失函数的值越小,表示模型在训练集上的预测结果与实际标签越接近,即模型的性能越好。
  2. 指导模型优化(反向传播): 通过最小化损失函数来调整模型的参数,使得模型能够更准确地预测目标值。优化过程就是通过调整模型参数来减小损失函数的过程。可通过这个过程得到梯度。

常见的损失函数

常见的损失函数包括均方误差(Mean Squared Error, MSE)、交叉熵损失(Cross Entropy Loss)、对数损失(Log Loss)等,具体选择哪种损失函数取决于问题的类型和输出的形式。

pytorch官方网址:torch.nn — PyTorch 2.4 documentation

 

 接下来看几个损失函数:

L1Loss 

torch.nn.L1Loss(size_average=Nonereduce=Nonereduction='mean')

 如图所示,是这个函数所使用的的公式。看起来很复杂,其实就是对所有的损失求平均或求和。

另外我们需要注意输入和输出格式:

 使用案例:

inputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)

inputs=torch.reshape(inputs,(1,1,1,3))
targets=torch.reshape(targets,(1,1,1,3))

#没有指定参数reduction,则默认按均值方式计算,除此之外,还可以使用reduction='sum',使其求和
loss=nn.L1Loss()
#计算公式:(|1-1|+|2-2|+|5-3|)/3≈0.667
result=loss(inputs,targets)
print(result)

 MSELoss

平方差损失函数。

torch.nn.MSELoss(size_average=Nonereduce=Nonereduction='mean')

与上面的使用方法基本相同。

loss=nn.MSELoss()
result=loss(inputs,targets)
print(result)
#输出结果:tensor(1.3333)

 CrossEntropyLoss

交叉熵损失函数,特别适用于多分类任务和输出为概率分布的情况。这个我首次接触是在机器学习中的逻辑回归那部分。如果想要了解更多可以看我另一篇文章:0_(机器学习)逻辑回归介绍-CSDN博客

torch.nn.CrossEntropyLoss(weight=Nonesize_average=Noneignore_index=-100

reduce=Nonereduction='mean'label_smoothing=0.0)

注意此函数的输入和输出与前面两种略有不同:

x=torch.tensor([0.1,0.2,0.3])
y=torch.tensor([1])
x=torch.reshape(x,(1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

 输出结果为:

tensor(1.1019)

反向传播(Backpropagation)

反向传播是一种有效的训练神经网络的方法,它利用链式法则计算损失函数对每个模型参数的梯度,并根据梯度更新参数。它是损失函数优化过程中的关键步骤。

作用:

  1. 计算梯度: 反向传播算法通过将损失函数的梯度从网络的输出层向输入层传播,计算每个参数对损失函数的影响程度。
  2. 参数更新: 计算得到的梯度可以用来更新模型的参数,使得损失函数值减小,从而提高模型的预测性能。

反向传播利用了链式法则来计算复杂的导数,高效地更新神经网络中的参数。这种算法使得深度学习模型可以在大量数据上进行训练,并从数据中学习到复杂的模式和关系。

测试代码:

from torch import nn
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset=torchvision.datasets.CIFAR10("../dataset",train=False,download=True,transform=torchvision.transforms.ToTensor())
dataloader=DataLoader(dataset,batch_size=64,drop_last=True)

class MyNn(nn.Module):
    def __init__(self) :
        super().__init__()
        self.conv1=nn.Conv2d(3,32,5,padding=2)
        self.maxpool1=nn.MaxPool2d(2)
        self.conv2=nn.Conv2d(32,32,5,padding=2)
        self.maxpool2=nn.MaxPool2d(2)
        self.conv3=nn.Conv2d(32,64,5,padding=2)
        self.maxpool3=nn.MaxPool2d(2)
        self.flatten=nn.Flatten()
        self.linear1=nn.Linear(1024,64)
        self.linear2=nn.Linear(64,10)
        '''使用Sequential可以简化代码'''
        # self.model1=nn.Sequential(
        #     nn.Conv2d(3,32,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(32,32,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(32,64,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Flatten(),
        #     nn.Linear(1024,64),
        #     nn.Linear(64,10)
        # )


    def forward(self,x):
        x=self.conv1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.maxpool2(x)
        x=self.conv3(x)
        x=self.maxpool3(x)
        x=self.flatten(x)
        x=self.linear1(x)
        x=self.linear2(x)
        # x=self.model1(x)
        return x

mynn=MyNn()
loss=nn.CrossEntropyLoss()
for data in dataloader:
    imgs,targets=data
    outputs=mynn(imgs)
    result_loss=loss(outputs,targets)
    result_loss.backward()
    print(result_loss)

在其中打上断点,如图:

然后debug运行查看参数,就可以看到在没有运行backward()函数之前,图中所指参数为空:

 那么现在往下运行一行试试:

 以上就是我们得到的梯度(梯度下降法中的梯度)。接下来选择合适的优化器,我们就可以对模型进行优化了。

优化器

官方文档:torch.optim — PyTorch 2.4 documentation

 用于优化神经网络模型的库,它的主要作用是实现各种优化算法,帮助模型在训练过程中更新参数以最小化损失函数。

有许许多多不同的优化器,入门阶段就不一一介绍。有需要可以自己查用法。

from torch import nn
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset=torchvision.datasets.CIFAR10("../dataset",train=False,download=True,transform=torchvision.transforms.ToTensor())
dataloader=DataLoader(dataset,batch_size=64,drop_last=True)

class MyNn(nn.Module):
    def __init__(self) :
        super().__init__()
        self.conv1=nn.Conv2d(3,32,5,padding=2)
        self.maxpool1=nn.MaxPool2d(2)
        self.conv2=nn.Conv2d(32,32,5,padding=2)
        self.maxpool2=nn.MaxPool2d(2)
        self.conv3=nn.Conv2d(32,64,5,padding=2)
        self.maxpool3=nn.MaxPool2d(2)
        self.flatten=nn.Flatten()
        self.linear1=nn.Linear(1024,64)
        self.linear2=nn.Linear(64,10)
        '''使用Sequential可以简化代码'''
        # self.model1=nn.Sequential(
        #     nn.Conv2d(3,32,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(32,32,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(32,64,5,padding=2),
        #     nn.MaxPool2d(2),
        #     nn.Flatten(),
        #     nn.Linear(1024,64),
        #     nn.Linear(64,10)
        # )


    def forward(self,x):
        x=self.conv1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.maxpool2(x)
        x=self.conv3(x)
        x=self.maxpool3(x)
        x=self.flatten(x)
        x=self.linear1(x)
        x=self.linear2(x)
        # x=self.model1(x)
        return x

mynn=MyNn()
loss=nn.CrossEntropyLoss()
optim=torch.optim.SGD(mynn.parameters(),lr=0.01)
for data in dataloader:
    imgs,targets=data
    outputs=mynn(imgs)
    result_loss=loss(outputs,targets)
    optim.zero_grad()
    result_loss.backward()
    optim.step()

优化输出结果,可以看到cost在逐渐变小:

tensor(358.2124, grad_fn=<AddBackward0>)
tensor(351.9559, grad_fn=<AddBackward0>)
tensor(328.9600, grad_fn=<AddBackward0>)
tensor(314.0949, grad_fn=<AddBackward0>)
tensor(306.4304, grad_fn=<AddBackward0>)
tensor(297.4454, grad_fn=<AddBackward0>)
tensor(288.6151, grad_fn=<AddBackward0>)
tensor(280.8827, grad_fn=<AddBackward0>)
tensor(273.8977, grad_fn=<AddBackward0>)
tensor(267.9252, grad_fn=<AddBackward0>)
tensor(262.4760, grad_fn=<AddBackward0>)
tensor(257.2650, grad_fn=<AddBackward0>)
tensor(252.2356, grad_fn=<AddBackward0>)
tensor(247.5495, grad_fn=<AddBackward0>)
tensor(243.2856, grad_fn=<AddBackward0>)
tensor(239.3630, grad_fn=<AddBackward0>)
tensor(235.6944, grad_fn=<AddBackward0>)
tensor(232.2151, grad_fn=<AddBackward0>)
tensor(228.8934, grad_fn=<AddBackward0>)
tensor(225.7084, grad_fn=<AddBackward0>)

截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1978547.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

神经网络基础--激活函数

&#x1f579;️学习目标 &#x1f579;️什么是神经网络 1.神经网络概念 2.人工神经网络 &#x1f579;️网络非线性的因素 &#x1f579;️常见的激活函数 1.sigmoid激活函数 2.tanh激活函数 3.ReLU激活函数 4.softmax激活函数 &#x1f579;️总结 &#x1f57…

计算机基础(Windows 10+Office 2016)教程 —— 第5章 文档编辑软件Word 2016(上)

第5章 文档编辑软件Word 2016 5.1 Word 2016入门5.1.1 Word 2016 简介5.1.2 Word 2016 的启动5.1.3 Word 2016 的窗口组成5.1.4 Word 2016 的视图方式5.1.5 Word 2016 的文档操作5.1.6 Word 2016 的退出 5.2 Word 2016的文本编辑5.2.1 输入文本5.2.3 插入与删除文本5.2.4 复制与…

二进制与进制转换与原码、反码、补码详解--内含许多超详细图片讲解!!!

前言 今天给大家分享一下C语言操作符的详解&#xff0c;但在此之前先铺垫一下二进制和进制转换与原码、反码、补码的知识点&#xff0c;都非常详细&#xff0c;也希望这篇文章能对大家有所帮助&#xff0c;大家多多支持呀&#xff01; 操作符的内容我放在我的下一篇文章啦&am…

基于人工智能的口试模拟、LLM将彻底改变 STEM 教育

概述 STEM教育是一种整合科学&#xff08;Science&#xff09;、技术&#xff08;Technology&#xff09;、工程&#xff08;Engineering&#xff09;和数学&#xff08;Mathematics&#xff09;的教育方法。这种教育模式旨在通过跨学科的方式培养学生的创新能力、问题解决能力…

MySQL 高级 - 第十四章 | 事务基础知识

目录 第十四章 事务基础知识14.1 数据库事务概述14.1.1 存储引擎支持情况14.1.2 基本概念14.1.3 事务的 ACID 特性14.1.4 事务的状态 14.2 如何使用事务14.2.1 显示事务14.2.2 隐式事务14.2.3 隐式提交数据的情况14.2.4 使用举例14.2.4.1 提交与回滚14.2.4.2 测试不支持事务的 …

Yarn:一个快速、可靠且安全的JavaScript包管理工具

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;还请三连支持一波哇ヾ(&#xff20;^∇^&#xff20;)ノ&#xff09; 目录 一、Yarn简介 二、Yarn的安装 1. 使用npm安装Yarn 2. 在macOS上…

11.redis的客户端-Jedis

1.Jedis 以redis命令作为方法名称&#xff0c;学习成本低&#xff0c;简单使用。但是jedis实例是不安全的&#xff0c;多线程环境下需要基于连接池来使用。 2.Lettuce lettuce是基于Netty实现的&#xff0c;支持同步&#xff0c;异步和响应式编程方式&#xff0c;并且是线程…

EmEditor 打开文档后光标如何默认定位到文档最后一行?

1、录制宏 &#xff08;1&#xff09;、点击工具栏上的红色录制宏按钮&#xff0c;开始录制宏。如图&#xff1a; &#xff08;2&#xff09;、按住快捷键Ctrl End快捷键&#xff0c;使光标跳转到文档末尾 &#xff08;3&#xff09;、完成录制后&#xff0c;再次点击录制按钮…

Hive SQL ——窗口函数源码阅读

前言 使用Starrocks引擎中的窗口函数 row_number() over( )对10亿的数据集进行去重操作&#xff0c;BE内存溢出问题频发&#xff08;忘记当时指定的BE内存上限是多少了.....&#xff09;&#xff0c;此时才意识到&#xff0c;开窗操作&#xff0c;如果使用 不当&#xff0c;反而…

stm32工程配置

目录 STM32F103 start&#xff1a;启动文件、内核寄存器文件、外设寄存器文件、时钟配置文件 library&#xff1a;标准库函数&#xff08;内核及外设驱动&#xff09; user&#xff1a;用户文件、库函数配置文件、中断程序文件 添加宏定义 STM32F407 start目录 启动文件…

实战:使用Certbot签发免费ssl泛域名证书(主域名及其它子域名共用同一套证书)-2024.8.4(成功测试)

1、使用Certbot签发免费ssl泛域名证书 | One实战&#xff1a;使用Certbot签发免费ssl泛域名证书(主域名及其它子域名共用同一套证书)-2024.8.4(成功测试)https://wiki.onedayxyy.cn/docs/docs/Certbot-install/

Transformer相关介绍

1 Transformer 介绍 Transformer的本质上是一个Encoder-Decoder的结构。 1.1 编码器 在Transformer模型中&#xff0c;编码器&#xff08;Encoder&#xff09; 的主要作用是将输入序列&#xff08;例如文本、语音等&#xff09;转换为隐藏表示&#xff08;或者称为特征表示…

24军dui文职联勤保障部报名照规格要求

24军dui文职联勤保障部报名照规格要求 #军队文职 #文职 #文职备考 #联勤保障部队 #文职考试 #文职上岸 #2024军队文职

python-查找元素3(赛氪OJ)

[题目描述] 有n个不同的数&#xff0c;从小到大排成一列。现在告诉你其中的一个数x&#xff0c;x不一定是原先数列中的数。你需要输出最后一个<x的数在此数组中的下标。输入&#xff1a; 输入共两行第一行为两个整数n、x。第二行为n个整数&#xff0c;代表a[i]。输出&#x…

练习2.30

2.29题目没有理解,暂时没有做出来,先把2.30做了 上代码 (defn square [x](* x x)) ;第一版,直接定义 (defn square-tree[tree](cond (not (seq? tree)) (square tree)(empty? tree) nil:else (cons (square-tree (first tree)) (square-tree (rest tree)))) ) ;使用map …

LeetCode刷题笔记 | 283 | 移动零 | 双指针 |Java | 详细注释

&#x1f64b;大家好&#xff01;我是毛毛张! &#x1f308;个人首页&#xff1a; 神马都会亿点点的毛毛张 原地移除元素2 LeetCode链接&#xff1a;283. 移动零 1.题目描述 给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元…

Nextjs——国际化那些事儿

背景&#xff1a; 某一天&#xff0c;产品经理跟我说&#xff0c;我们的产品需要搞国际化 国际化的需求说白了就是把项目中的文案翻译成不同的语言&#xff0c;用户想用啥语言来浏览网页就用啥语言&#xff0c;虽然说英语是通用语言&#xff0c;但国际化了嘛&#xff0c;产品才…

学习编程的第二十天,加油!

3&#xff1a;递归与迭送&#xff08;循环是一种迭代&#xff09; &#xff01;&#xff01;&#xff01;递归算有些东西时计算量会很大导致运行时间过久&#xff0c;而使用循环会大大节省时间&#xff0c;但需要注意溢出的情况。 递归的练习&#xff0c;第一张呢不符合我们的…

刷题——不同路径的数目

不同路径的数目(一)_牛客题霸_牛客网 我第一眼&#xff0c;觉得是没有思路的&#xff0c;我也是看别人代码反应过来&#xff0c; 画图可以看出来 外边沿的只有一种到达方式&#xff0c;全部赋值1&#xff0c; 如果有两个方块相接&#xff0c;那就让此方块的左邻和右邻相加&…

线程池ThreadPoolExecutor使用

文章目录 一、基础-Java中线程创建的方式1.1、继承Thread类创建线程1.2、实现Runnable接口创建线程1.3、实现Calable接口创建线程1.4、使用线程池创建线程二、概念-线程池基本概念2.1、并发和井行的主要区别2.1.1、处理任务不同2.1.2、存在不同2.1.3、CPU资源不同2.2、什么是线…