【学习笔记】【Pytorch】12.损失函数与反向传播

news2025/1/11 8:40:15

【学习笔记】【Pytorch】12.损失函数与反向传播

  • 一、损失函数的介绍
    • 1.L1Loss类的使用
      • 代码实现
    • 2.MSELoss类的使用
    • 3.损失函数在模型中的实现
  • 二、反向传播

一、损失函数的介绍

参考
损失函数(loss function)
pytorch loss-functions 文档

作用
1.计算实际输出和目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

1.L1Loss类的使用

from torch.nn import L1Loss
class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

参考
torch.nn.L1Loss

作用:创建一个衡量输入x(模型预测输出)和目标y之间差的绝对值的平均值的标准。

在这里插入图片描述
reduction = ‘mean’
在这里插入图片描述

代码实现

import torch
from torch.nn import L1Loss


input = torch.tensor([1, 2, 3], dtype=torch.float32)  # 创建一个一阶张量(3个元素)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)  # 创建一个一阶张量(3个元素)

# 转化为四阶张量(不转化也可以)
inputs = torch.reshape(input, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

loss1 = L1Loss()  # 创建一个实例
loss2 = L1Loss(reduction="sum")  # 创建一个实例
loss3 = L1Loss(reduction="none")  # 创建一个实例
result1 = loss1(inputs, targets)
result2 = loss2(inputs, targets)
result3 = loss3(inputs, targets)

print(result1)
print(result2)
print(result3)

输出

tensor(0.6667)
tensor(2.)
tensor([[[[0., 0., 2.]]]])

2.MSELoss类的使用

from torch.nn import MSELoss
class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

在这里插入图片描述
参考
torch.nn.MSELoss

作用:创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准。

reduction = ‘mean’
在这里插入图片描述

# -*- coding: UTF-8 -*-
# 开发团队  :桂林电子科技大学 - 人工智能学院 - chen
# 开发人员  :Chen
# 开发时间  :2023/1/15 21:46
# 开发名称  :18.nn.loss.py
# 开发工具  :PyCharm
import torch
from torch.nn import MSELoss


input = torch.tensor([1, 2, 3], dtype=torch.float32)  # 创建一个一阶张量(3个元素)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)  # 创建一个一阶张量(3个元素)

# 转化为四阶张量(不转化也可以)
inputs = torch.reshape(input, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

loss1 = MSELoss()  # 创建一个实例
loss2 = MSELoss(reduction="sum")  # 创建一个实例
loss3 = MSELoss(reduction="none")  # 创建一个实例
result1 = loss1(inputs, targets)
result2 = loss2(inputs, targets)
result3 = loss3(inputs, targets)

print(result1)
print(result2)
print(result3)

输出

tensor(1.3333)
tensor(4.)
tensor([[[[0., 0., 4.]]]])

3.损失函数在模型中的实现

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()  # 初始化父类属性
        self.model1 = Sequential(
            Conv2d(3, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

loss = nn.CrossEntropyLoss()  # 交叉熵
model = Model()  # 创建一个实例
for data in dataloader:
    imgs, targets = data
    output = model(imgs)
    result_loss = loss(output, targets)
    print(result_loss)  # 输出误差

输出

Files already downloaded and verified
tensor(2.4058, grad_fn=<NllLossBackward0>)
tensor(2.2368, grad_fn=<NllLossBackward0>)
tensor(2.2519, grad_fn=<NllLossBackward0>)
...
...

二、反向传播

参考
深度学习 | 反向传播详解

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()  # 初始化父类属性
        self.model1 = Sequential(
            Conv2d(3, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

loss = nn.CrossEntropyLoss()  # 交叉熵
model = Model()  # 创建一个实例
for data in dataloader:
    imgs, targets = data
    output = model(imgs)
    result_loss = loss(output, targets)
    result_loss.backward()  # 反向传播
    print(result_loss)  # 输出误差

Bebug模型:查看卷积核的梯度
model -> Protected Attributes -> _models -> 0 -> weight -> grad
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/167038.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaEE】基于TCP的客户端服务器程序

✨哈喽&#xff0c;进来的小伙伴们&#xff0c;你们好耶&#xff01;✨ &#x1f6f0;️&#x1f6f0;️系列专栏:【JavaEE】 ✈️✈️本篇内容:基于TCP的客户端服务器程序。 &#x1f680;&#x1f680;代码存放仓库gitee&#xff1a;JavaEE初阶代码存放&#xff01; ⛵⛵作者…

【JavaEE初阶】第二节.进程篇

文章目录 前言 一、操作系统 二、进程 2.1 进程的概念 2.2 进程的管理​​​​​​​​​​​​​​ 2.3 PCB 2.3.1 PCB里面的一些属性 2.3.2 进程的调度 2.3.3 进程的虚拟地址空间 2.3.4 进程间通信 总结 前言 本节内容我们继续对JavaEE的有关内容进行学习&#xff0c;…

汽车智能化,集度做加法

CES2023刚刚落下帷幕&#xff0c;这场名为“国际消费电子展”的业界盛会&#xff0c;近几年重心正明显转向智能汽车及其周边产业链。在2022年的CES上&#xff0c;集度与英伟达宣布合作&#xff0c;也透露了智能汽车研发的相关计划。而在本届CES之前、2022年末的广州车展上&…

一个关于image访问图片跨域的问题

一、背景 项目中遇到一个问题&#xff0c;同一个图片在 dom 节点中使用了 img 标签来加载&#xff0c;同时由于项目使用了 ThreeJS 3D 渲染引擎&#xff0c;在加载纹理时使用了 TextureLoader 来加载了同一张图片&#xff0c;而由于图片是在阿里云服务器上的&#xff0c;所以最…

SourceTree 拉取、重置提交、回滚、变基与合并

SourceTree的重置当前分支到此次提交 使用场景&#xff1a;“我想把已提交未推送的修改撤销” 使用模式说明软合并软合并是指将此次提交回滚到指定提交位置&#xff0c;但这个过程中会将修改过的文件暂存到暂存区。混合合并混合合并是指将此次提交回滚到指定的位置&#xff0c…

本来挺喜欢刷《剑指offer》的.......(第十一天)

跟着博主一起刷题 这里使用的是题库&#xff1a; https://leetcode.cn/problem-list/xb9nqhhg/?page1 目录剑指 Offer 66. 构建乘积数组剑指 Offer 68 - I. 二叉搜索树的最近公共祖先剑指 Offer 68 - II. 二叉树的最近公共祖先剑指 Offer 66. 构建乘积数组 剑指 Offer 66. 构建…

使用react-bmapgl绘制区域并判断是否重叠

需求如下&#xff1a; 在react项目中使用百度地图实现区域&#xff08;电子围栏&#xff09;的绘制绘制的区域类型为&#xff1a;1、多边形 2、圆形可绘制多个区域区域不能有重叠可重新编辑区域 代码如下: index.tsx import { useCallback, useEffect, useState } from rea…

Python入门实践(二)——变量的使用

文章目录变量1、变量的命名和使用1.1、避免命名错误2、字符串2.1、修改字符串大小写2.2、合并&#xff08;拼接&#xff09;字符串2.3、使用制表符或换行符来添加空白2.4、删除空白3、数字3.1、整数3.2、浮点数3.3、使用str()避免类型错误4、注释变量是对一种数据结构的命名&am…

2023年基建工程(设计规划施工)经验分享,超多干货

为了彻底打通从工程外业勘探调查、数据资料整理&#xff0c;到内业详细设计之间的一系列障碍&#xff0c;结合工程外业调查的特点&#xff0c;基于安卓&#xff08;Android&#xff09;操作系统&#xff0c;精心打磨推出了“外业精灵”移动端应用软件。 该系统把工程外业探勘、…

MPP数据库简介及架构分析

目录什么是MPP&#xff1f;特性并行处理超大规模数据仓库真正适合什么典型的分析工作量数据集中化线性可伸缩性MPP架构技术特性数据库架构分析Shared EverythingShared DiskShare MemoryShared NothingShared Nothing数据库架构优势什么是MPP&#xff1f; MPP (Massively Paral…

分享88个C源码,总有一款适合您

C源码 分享88个C源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c;大家下载后可以看到。 源码下载链接&#xff1a;https://pan.baidu.com/s/1TT87gt66kn5BtLqgRUTlUQ?pwdwje5 提取码…

Java图形化界面---JOptionPane

目录 一、JOptionPane的介绍 二、JOptionalPane的使用 &#xff08;1&#xff09;消息对话框 &#xff08;2&#xff09; 确认对话框 &#xff08;3&#xff09;输入对话框 &#xff08;4&#xff09;选项对话框 一、JOptionPane的介绍 通过JOptionPane可以非常方便地创建…

SpringCloud复习之Sleuth+Zipkin链路追踪实战

文章目录写作背景为什么要有链路监控SpringCloud SleuthZipkin能做什么上手实战启动一个Zipkin Server微服务集成SleuthZipkin写作背景 前面复习了SpringCloud Netflix的几个核心组件&#xff0c;包括Eureka、Ribbon、Feign、Hystrix、Zuul&#xff0c;并进行了Demo级别的实战…

高精度减法【c++】超详细讲解

前言 大家学过高精度加法之后&#xff0c;可能已经知道高精度减法的实现方法了吧 如果你还没有学过高精度加法的话&#xff0c;请点击这里&#xff08;很详细的&#xff09;—>高精度加法【C实现】详解 最大的问题 最大的问题莫过于负数问题了。其他方法和加法一样。 负…

4.二级缓存解析

文章目录1. 二级缓存配置2. 二级缓存结构3. 二级缓存命中条件4. 缓存空间的理解5. 二级缓存执行流程二级缓存也称作是应用级缓存&#xff0c;与一级缓存不同的&#xff0c;是它的作用范围是整个应用&#xff0c;而且可以跨线程使用。所以二级缓存有更高的命中率&#xff0c;适合…

从南丁格尔图到医学发展史

可视化中&#xff0c;前端用于表现不同类目的数据在总和中的占比的场景&#xff0c;往往会采用饼图。 针对数据大小相近&#xff0c;南丁格尔图的呈现会更加美观。 南丁格尔图&#xff0c;又称玫瑰图&#xff0c;是由弗罗伦斯南丁格尔发明。 弗洛伦斯南丁格尔 开创了护理事业…

二、django中的路由系统

django中的路由系统 django中路由的作用和路由器类似&#xff0c;当一个用户请求Django站点的一个页面时&#xff0c;是路由系统通过对url的路径部分进行匹配&#xff0c;一旦匹配成功就导入并执行对应的视图来返回响应。 django如何处理请求 当一个请求来到时&#xff0c;d…

SpringSecurityOauth2架构Demo笔记

总体分为SpringSecurityOauth2授权码模式演示和密码模式演示 一直下一步,依赖手动导入,SpringBoot版本改成2.2.5.RELEASE,JDK版本1.8 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xml…

Open3D 点云投影至指定球面(Python版本)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 假设球体的相关参数:中心为 C ( x c , y c , z c ) C(x_c,y_c,z_c)

【数据结构和算法】栈—模拟实现Stack和栈相关算法题

文章目录栈的定义Stack模拟实现相关算法题1.栈的压入弹出序列2.逆波兰表达式(后缀表达式)⭐1.什么是逆波兰表达式?如何转换成逆波兰表达式逆波兰表达式如何计算3.有效的括号总结栈的定义 栈作为一种数据结构&#xff0c;是一种只能在一端进行插入和删除操作的特殊线性表。它按…