8-pytorch-损失函数与反向传播

news2024/10/6 1:34:38

b站小土堆pytorch教程学习笔记

根据loss更新模型参数
1.计算实际输出与目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

在这里插入图片描述

1 MSEloss

import torch
from torch.nn import L1Loss
from torch import nn

inputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)

inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))

loss=L1Loss()
result=loss(inputs,targets)

loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)

print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

在这里插入图片描述

x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=1)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:
    imgs,target=data
    output=han(imgs)
    # print(target)
    # print(output)
    result_loss=loss(output,target)
    print(result_loss)

*tensor([7])
tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)…

tensor(2.2664, grad_fn=)…

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=64)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)

for epoch in range(5):
    running_loss=0.0#一个epoch结束的loss和
    for data in dataloader:
        imgs,target=data
        output=han(imgs)

        result_loss=loss(output,target)#每次迭代的loss
        optim.zero_grad()#将网络中每个可调节参数对应的梯度调为0
        result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得
        optim.step()#对每个参数调优
        running_loss=running_loss+result_loss
    print(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1469040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WEB相关工具(wget、curl、ab)

目录 一、wget 1、wget基本语法 2、wget帮助的更多选项 二、curl 1、curl基本语法 2、curl命令基本用法 2.1 curl伪装 2.2 提取状态码 2.3 提取本地IP地址 2.4 提取远端服务器IP地址 2.5 提取本地端口 2.6 提取远端服务器端口 三、压力测试工具 1、常用的httpd压…

数据结构与算法相关题解20240225

数据结构与算法相关题解20240225 一、58. 最后一个单词的长度二、48. 旋转图像三、69. x 的平方根四、50. Pow(x, n) 一、58. 最后一个单词的长度 简单 给你一个字符串 s,由若干单词组成,单词前后用一些空格字符隔开。返回字符串中 最后一个 单词的长度…

基于springboot+vue的租房管理系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

快速查找/打开host文件的方法

hosts文件是一个没有扩展名的文件,主要作用是:保存与域名的映射关系。 配置格式: ip 域名 windows系统里的保存位置: C:\Windows\System32\drivers\etc 下面介绍快速打开的方法。 第一步 [winR]打开运行,输入下面的…

精品基于SpringBoot+Vue的常规应急物资管理系统

《[含文档PPT源码等]精品基于SpringBootVue的常规应急物资管理系统[包运行成功]》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程、包运行成功! 软件开发环境及开发工具: Java——涉及技术: 前端使用技术&#xff…

应用回归分析:贝叶斯回归

贝叶斯回归是一种统计方法,它利用贝叶斯定理来更新对回归参数的估计。这种方法不仅考虑了数据的不确定性,还考虑了模型参数的不确定性,为预测提供了一个更加全面的框架。在本文中,我们将深入探讨贝叶斯回归的基本概念、如何实现它…

Docker容器实战

"爱在,地图上,剥落~" Mysql 容器化安装 我们可以在 docker hub上,进入mysql的镜像仓库,找到适合的版本。 直接拉取镜像: docker pull mysql:latest 我们知道 msyql 的默认端口是 3306 ,而且有密码&#x…

ArcgisForJS如何将ArcGIS Server发布的点要素渲染为热力图?

文章目录 0.引言1.ArcGIS创建点要素2.ArcGIS Server发布点要素3.ArcgisForJS将ArcGIS创建的点要素渲染为热力图 0.引言 ArcGIS For JS 是一个强大的地理信息系统(GIS)工具,它允许开发者使用 JavaScript 语言来创建各种 GIS 应用。ArcGIS Ser…

2.5G/5G/10G高速率网络变压器(网络隔离变压器)产品介绍(1)

Hqst华轩盛(石门盈盛)电子导读:高速率/2.5G 的带POE插件(DIP)款千兆双口网络变压器2G54801DP特点 一 ﹑2.5G高速率网络变压器(网络隔离变压器):2G54801DP外观与尺寸 2G54801DP这颗产品尺寸为:长…

应用回归分析:非参数回归

非参数回归是一种统计方法,它在建模和分析数据时不假设固定的模型形式。与传统的参数回归模型不同,如线性回归和多项式回归,非参数回归不需要预先定义模型的结构(例如,模型是否为线性或多项式)。这使得非参…

Python爬虫-付费代理推荐和使用

付费代理的使用 相对免费代理来说,付费代理的稳定性更高。本节将介绍爬虫付费代理的相关使用过程。 1. 付费代理分类 付费代理分为两类: 一类提供接口获取海量代理,按天或者按量收费,如讯代理。 一类搭建了代理隧道&#xff0…

矩阵的导数运算(理解分子布局、分母布局)

矩阵的导数运算(理解分子布局、分母布局) 1、分子布局和分母布局 请思考这样一个问题,一个维度为m的向量y对一个标量x的求导,那么结果也是一个m维的向量,那么这个结果向量是行向量,还是列向量呢? 答案是&#xff1a…

故障诊断 | 一文解决,PSO-BP粒子群算法优化BP神经网络模型的故障诊断(Matlab)

文章目录 效果一览文章概述模型描述源码设计参考资料效果一览 文章概述 故障诊断 | 一文解决,PSO-BP粒子群算法优化BP神经网络模型的故障诊断(Matlab) 粒子群优化算法(Particle Swarm Optimization, PSO)是一种群体智能优化算法,用于求解优化问题。BP神经网络是一种用于模…

备战蓝桥杯————二叉树解题思维1

解决二叉树问题时,常采用两种思维模式: 遍历思维模式: 这种思维模式强调是否可以通过一次遍历二叉树来得到答案。通常使用一个遍历函数(比如前序、中序、后序遍历)结合外部变量来实现。这种方法适用于需要在每个节点上…

读人工不智能:计算机如何误解世界笔记02_Hello,world

1. Hello,world 1.1. “Hello,world”是布赖恩克尼汉和丹尼斯里奇于1978年出版的经典著作《C程序设计语言》中的第一个编程项目 1.2. 贝尔实验室可以说是现代计算机科学界中的智库,地位好比巧克力界的好时巧克力 1.3. 计算机科学界的大量创…

(响应数据)学习SpringMVC的第三天

响应数据 一 . 传统同步业务数据响应 1.1 请求资源转发与请求资源重定向的区别 请求资源转发时,froward:可不写 二 . 前后端分离异步方式 回写json格式的字符串 1 用RestController代替Controller与 ResponseBody 2 . 直接返回user对象实体 , 即可向 前端ajax 返回json字…

第七篇:CamX Sensor Bringup

第七篇:CamX Sensor Bringup 一、sensor 驱动文件编写 sensor驱动相关的文件目录在chi-cdk/oem/qcom/sensor 下。一般如果能直接从模组厂上拿到已经写好的驱动文件,那是最好的了。 如果没有,那就只能是拿到提供的寄存器setting参数,自己来写。 我们可以参考已有的驱动文…

【Linux基础】Linux自动化构建工具make/makefile

背景 会不会写makefile,从一个侧面说明了一个人是否具备完成大型工程的能力一个工程中的源文件不计数,其按类型、功能、模块分别放在若干个目录中,makefile定义了一系列的规则来指定,哪些文件需要先编译,哪些文件需要后…

异步http和同步http原理和差异

开发服务器端程序时,一种常见的需求是,通过向另一个http服务器发送请求,获得数据。最常规的作法是使用同步http请求的方式,过程如下 这种方式简单好用,但是在高并发场景下有缺陷。在单线程环境下,程序发送h…

台式电脑无法进桌面问题

楼主家里的台式电脑有一段时间进不了桌面,一度很困扰。 最开始发现有一个存储盘没有显示,拆开主机盖,把显卡、内存、硬盘都重新往紧压了下。重新开机后,显示器还是黑的。 表现为主机启动的声音正常,显示器没有信号接…