b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器

news2025/2/28 3:12:33

文章目录

  • 一、损失函数
    • 1.简要介绍
    • 2.代码
  • 二、优化器
    • 1.简要介绍
    • 2.代码

一、损失函数

1.简要介绍

可参考博客:

常见的损失函数总结

损失函数的全面介绍

pytorch学习之十九种损失函数

损失函数(Loss Function)是用来衡量模型预测输出与实际标签之间的差异或误差程度的函数。在深度学习中,损失函数通常被设计为一个标量值,表示模型的预测值与真实标签之间的差异。

损失函数的选择对于训练深度学习模型非常重要,因为它直接影响着模型的训练效果和性能。在训练过程中,通过最小化损失函数来调整模型参数,使模型的预测结果逐渐接近真实标签,从而提高模型的准确性。

常见的损失函数:

均方误差(Mean Squared Error,MSE):用于回归任务,计算预测值与真实值之间的平方差的均值。

交叉熵损失函数(Cross Entropy Loss):用于分类任务,衡量模型输出的概率分布与真实标签的差异。

对数损失函数(Log Loss):也常用于二分类或多分类问题,衡量模型输出类别的概率与真实标签之间的关系。

Hinge损失函数:通常用于支持向量机(SVM)中,用于处理二分类问题。

Kullback-Leibler 散度(KL 散度):用于衡量两个概率分布之间的相似度。

2.代码

import torch
from torch import nn

# 定义输入张量和目标张量
inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

# 对输入和目标张量进行reshape操作以匹配损失函数的输入要求
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

# 实例化 L1 损失函数
loss = nn.L1Loss()
# 计算 L1 损失值
result = loss(inputs, targets)
print(result)

# 实例化均方误差(MSE)损失函数
loss_mse = nn.MSELoss()
# 计算均方误差损失值
result2 = loss_mse(inputs, targets)
print(result2)

代码运行结果:

在这里插入图片描述

二、优化器

1.简要介绍

优化器是深度学习中用于更新模型参数以最小化损失函数的算法。在神经网络训练过程中,通过计算损失函数对模型参数的梯度,优化器根据这些梯度来更新模型参数,使得损失函数逐渐减小,从而使模型更好地拟合训练数据。

2.代码

import torch.utils.data
import torchvision.datasets
from torch import nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

# 加载 CIFAR-10 数据集
datasets = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=True)

# 创建数据加载器
dataloader = DataLoader(datasets, batch_size=1)

# 定义神经网络模型 Guodong
class Guodong(nn.Module):
    def __init__(self):
        super(Guodong, self).__init__()
        self.module1 = Sequential(
            Conv2d(3, 32, 5, padding=2),  # 输入通道数为3,输出通道数为32,卷积核大小为5,填充为2
            MaxPool2d(2),  # 最大池化层,核大小为2
            Conv2d(32, 32, 5, padding=2),  # 输入通道数为32,输出通道数为32,卷积核大小为5,填充为2
            MaxPool2d(2),  # 最大池化层,核大小为2
            Conv2d(32, 64, 5, padding=2),  # 输入通道数为32,输出通道数为64,卷积核大小为5,填充为2
            MaxPool2d(2),  # 最大池化层,核大小为2
            Flatten(),  # 将多维输入展平为一维
            Linear(1024, 64),  # 全连接层,输入维度为1024,输出维度为64
            Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10
        )

    def forward(self, input):
        output = self.module1(input)
        return output

# 实例化 Guodong 模型
guodong = Guodong()

# 定义交叉熵损失函数
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(guodong.parameters(), lr=0.01)
for epoch in range(20):
    loss_sum = 0.0
    # 遍历数据加载器中的数据
    for data in dataloader:
        imgs, target = data
        # 将图片输入模型得到预测输出
        outputs = guodong(imgs)
        # 计算交叉熵损失值
        result_loss = loss(outputs, target)
        optim.zero_grad()
        # 反向传播计算梯度
        result_loss.backward()
        optim.step()
        loss_sum += result_loss
    print(loss_sum)


optim.zero_grad()
result_loss.backward()
optim.step()
这三处设置断点,调试,可以看到grad一开始是None,后来有了具体的数值

在这里插入图片描述
在这里插入图片描述
代码打印结果为:

在这里插入图片描述
(后面还没打印出来,程序运行有点慢QAQ)

可以看到最开始的时候loss_sum在变小,后来又变大。

在深度学习训练过程中,损失函数的值不一定是单调递减的,特别是在使用随机梯度下降(SGD)等基于随机采样的优化算法时。因此,损失函数值的变化可能会出现波动或不规则的情况。

sum_loss 的数值一开始是在减小的,但后来又增大了。这可能是由多种原因引起的,例如:

(1)训练数据的顺序:在每个 epoch 中,数据加载器可能以不同的顺序提供训练样本,这会导致模型参数的更新方向有所不同,从而影响损失函数的变化。

(2)学习率的设置:学习率控制着参数更新的步长大小,如果学习率设置得过大,可能会导致参数更新过程不稳定,损失函数值出现震荡或上升。

(3)模型复杂度和数据集的匹配程度:如果模型的复杂度过高,而训练数据集较小或难以拟合,模型可能会出现过拟合现象,导致损失函数值增大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开发指南002-前后端信息交互规范-概述

前后端之间采用restful接口,服务和服务之间使用feign。信息交互遵循如下平台规范: 前端: 建立api目录,按照业务区分建立不同的.js文件,封装对后台的调用操作。其中qlm*.js为平台预制的接口文件,以qlm_user.…

离线数仓(五)【数据仓库建模】

前言 今天开始正式数据仓库的内容了, 前面我们把生产数据 , 数据上传到 HDFS , Kafka 的通道都已经搭建完毕了, 数据也就正式进入数据仓库了, 解下来的数仓建模是重中之重 , 是将来吃饭的家伙 ! 以及 Hive SQL 必须熟练到像喝水一样 ! 第1章 数据仓库概述 1.1 数据仓库概念 数…

【stm32 外部中断】

中断:在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行 中断优先级:当有多个中…

mybatis-plus整合spring boot极速入门

使用mybatis-plus整合spring boot,接下来我来操作一番。 一,创建spring boot工程 勾选下面的选项 紧接着,还有springboot和依赖我们需要选。 这样我们就创建好了我们的spring boot,项目。 简化目录结构: 我们发现&a…

未来城市:探索数字孪生在智慧城市中的实际应用与价值

目录 一、引言 二、数字孪生与智慧城市的融合 三、数字孪生在智慧城市中的实际应用 1、智慧交通管理 2、智慧能源管理 3、智慧建筑管理 4、智慧城市管理 四、数字孪生在智慧城市中的价值 五、挑战与展望 六、结论 一、引言 随着科技的飞速发展,智慧城市已…

R统计学2 - 数据分析入门问题21-40

往期R统计学文章: R统计学1 - 基础操作入门问题1-20 21. 如何对矩阵按行 (列) 作计算? 使用函数 apply() vec 1:20 # 转换为矩阵 mat matrix (vec , ncol4) # [,1] [,2] [,3] [,4] # [1,] 1 6 11 16 # [2,] 2 7 12 17 # [3,] …

前端框架的发展历史介绍

前端框架的发展历史是Web技术进步的一个重要方面。从最初的简单HTML页面到现在的复杂单页应用程序(SPA),前端框架和库的发展极大地推动了Web应用程序的构建方式。以下是一些关键的前端框架和库,以及它们的发布年份、创建者和主要特…

UnicodeDecodeError: ‘gbk‘和Error: Command ‘pip install ‘pycocotools>=2.0

今天重新弄YOLOv5的时候发现不能用了,刚开始给我报这个错误 subprocess.CalledProcessError: Command ‘pip install ‘pycocotools>2.0‘‘ returned non-zero exit statu 说这个包安装不了 根据他的指令pip install ‘pycocotools>2.0这个根…

从零开始:神经网络(2)——MP模型

声明:本文章是根据网上资料,加上自己整理和理解而成,仅为记录自己学习的点点滴滴。可能有错误,欢迎大家指正。 神经元相关知识,详见从零开始:神经网络——神经元和梯度下降-CSDN博客 1、什么是M-P 模型 人…

CorelDRAW Graphics Suite2024专业图形设计软件Windows/Mac最新25.0.0.230版

CorelDRAW Graphics Suite 2024是一款专业的图形设计软件,它集成了CorelDRAW Standard 2024和其他高级图形处理工具,为用户提供了全面的图形设计和编辑解决方案。 该软件拥有强大的矢量编辑功能,用户可以轻松创建和编辑矢量图形,…

数字化转型导师坚鹏:科技金融政策、案例及数字化营销

科技金融政策、案例及数字化营销 课程背景: 很多银行存在以下问题: 不清楚科技金融有哪些利好政策? 不知道科技金融有哪些成功案例? 不知道科技金融如何数字化营销? 课程特色: 以案例的方式解读原…

聚类简单讲解

聚类任务 聚类任务是指将一组数据分成多个不同的组(或簇),使得同一组内的数据点彼此相似,而不同组之间的数据点尽可能不相似的过程。聚类任务的目标是发现数据中的固有结构,而不需要事先知道数据的类别信息。聚类算法…

IntelliJ IDEA Dev 容器

​一、dev 容器 开发容器(dev 容器)是一个 Docker 容器,配置为用作功能齐全的开发环境。 IntelliJ IDEA 允许您使用此类容器来编辑、构建和运行您的项目。 IntelliJ IDEA 还支持多个容器连接,这些连接可以使用 Docker Compose …

多种方法求解数组排序

𝙉𝙞𝙘𝙚!!👏🏻‧✧̣̥̇‧✦👏🏻‧✧̣̥̇‧✦ 👏🏻‧✧̣̥̇:Solitary_walk ⸝⋆ ━━━┓ - 个性标签 - :来于“云”的“羽球人”。…

Day29:安全开发-JS应用DOM树加密编码库断点调试逆向分析元素属性操作

目录 JS原生开发-DOM树-用户交互 JS导入库开发-编码加密-逆向调试 思维导图 JS知识点: 功能:登录验证,文件操作,SQL操作,云应用接入,框架开发,打包器使用等 技术:原生开发&#x…

HTML 学习笔记——标签创建小技巧

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Emmrt工具</title></head><body>&…

超越Chain-of-Thought LLM 推理

原文地址&#xff1a;Beyond Chain-of-Thought LLM Reasoning 2024 年 2 月 13 日 介绍 最近的一项研究解决了需要增强大型语言模型 (LLM) 的推理能力&#xff0c;超越直接推理 (Direct Reasoning&#xff0c;DR) 框架&#xff0c;例如思想链和自我一致性&#xff0c;这些框架可…

ARM/Linux嵌入式面经(一):海康威视

海康威视 1.函数指针和指针函数区别 1.定义的差异 函数指针&#xff1a;函数指针的定义涉及到函数的地址。例如&#xff0c;定义一个指向函数的指针 int (*fp)(int)&#xff0c;这里 fp 是一个指针&#xff0c;它指向一个接受一个整数参数并返回整数的函数。 指针函数&#…

了解华为(PVID VLAN)与思科的(Native VLAN)本征VLAN的区别并学习思科网络中二层交换机的三层结构局域网VLAN配置

一、什么是二层交换机&#xff1f; 二层交换机&#xff08;Layer 2 Switch&#xff09;是一种网络设备&#xff0c;主要工作在OSI模型的数据链路层&#xff08;第二层&#xff09;&#xff0c;用于在局域网内部进行数据包的交换和转发。二层交换机通过学习MAC地址表&#xff0…

Excel F4键的作用

目录 一. 单元格相对/绝对引用转换二. 重复上一步操作 一. 单元格相对/绝对引用转换 ⏹ 使用F4键 如下图所示&#xff0c;B1单元格引用了A1单元格的内容。此时是使用相对引用&#xff0c;可以按下键盘上的F4键进行相对引用和绝对引用的转换。 二. 重复上一步操作 ⏹添加或删除…