pytorch初学笔记(十四):损失函数

news2024/11/26 22:52:09

 目录

一、损失函数 

1.1 L1损失函数

1.1.1 简介

1.1.2 参数设定

1.1.3 代码实现

1.2 MSE损失函数(平方和)

1.2.1 简介

1.2.2 参数介绍

1.2.3 代码实现

1.3 损失函数的作用

二、在神经网络中使用loss function

2.1 使用交叉熵损失函数 

2.2 反向传播


一、损失函数 

torch.nn — PyTorch 1.13 documentation

每一个样本经过模型后会得到一个预测值,然后得到的预测值和真实值的差值就成为损失(当然损失值越小证明模型越是成功),我们知道有许多不同种类的损失函数,这些函数本质上就是计算预测值和真实值的差距的一类型函数,然后经过库(如pytorch,tensorflow等)的封装形成了有具体名字的函数。

1.1 L1损失函数

1.1.1 简介

L1损失函数: 基于逐像素比较差异,然后取绝对值。

在这里插入图片描述 

1.1.2 参数设定 

L1Loss — PyTorch 1.13 documentation

 CLASS torch.nn.L1Loss(size_average=Nonereduce=Nonereduction='mean')

 

我们一般设定reduction的值来显示平均值或者和。

 参数设定:

reduction可取的值: 'none' | 'mean' | 'sum'

  • 'none': no reduction will be applied
  •  'mean': the sum of the output will be divided by the number of elements in the output,求的是平均值,即各个差求和之后除以总数。

  •  'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction.只求和,不除总数。

  • Default: 'mean'

1.1.3 代码实现

 设置reduction的值为默认值和sum,观察区别。

import torch
from torch.nn import L1Loss

inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)

inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))

loss1 = L1Loss()
result1 = loss1(inputs,targets)
print(result1)

loss2 = L1Loss(reduction="sum")
result2 = loss2(inputs,targets)
print(result2)
  • 当取值为默认值mean时,求的是平均值,sum=(1-1+2-2+5-3)=2, n=3, result = sum/n=0.6667 
  • 当取值为sum时,求的是和,即result=2

 

1.2 MSE损失函数(平方和)

1.2.1 简介

均方误差(Mean Square Error,MSE)是回归损失函数中最常用的误差,它是预测值f(x)与目标值y之间差值平方和的均值,其公式如下所示:
在这里插入图片描述 

 

1.2.2 参数介绍

MSELoss — PyTorch 1.13 documentation

 

与上面的L1损失函数一样,我们可以改变reduction的值来进行对应数值的输出。

 

1.2.3 代码实现

import torch
from torch.nn import L1Loss, MSELoss

inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)

inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))

loss_mse1 = MSELoss()
result1 = loss_mse1(inputs,targets)
print(result1)

loss_mse2 = MSELoss(reduction="sum")
result2 = loss_mse2(inputs,targets)
print(result2)

 可以看到reduction设置不同的值对应的输出也不同。

 

1.3 损失函数的作用

  1. 计算实际输出和目标之间的差距
  2. 为更新输出(反向传播)提供一定的依据

二、在神经网络中使用loss function

2.1 使用交叉熵损失函数 

使用上次定义的神经网络和CIFAR10数据集进行图像分类,分类问题使用交叉熵损失函数。

import torch.nn
from torch import nn
import torchvision.datasets
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)

class Maweiyi(torch.nn.Module):
    def __init__(self):
        super(Maweiyi, self).__init__()
        self.model1 = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(in_features=1024, out_features=64),
            Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
         x = self.model1(x)
         return x

maweiyi = Maweiyi()
# 使用交叉熵损失函数
loss_cross = nn.CrossEntropyLoss()

for data in dataloader:
    imgs,labels = data
    outputs = maweiyi(imgs)
    results = loss_cross(outputs,labels)
    print(results)

可以看到使用loss function计算出了在神经网路中预测的output和真实值labels之间的差距大小。 

 

2.2 反向传播

results_loss.backward() 

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/50089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软件测试】资深测试聊一聊,测试架构师是怎么样的,做一名成功的测试工程师......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 测试架构师 测试架…

利尔达5G模组NE16U-CN通过华为OpenLab基于R16标准的认证测试

近日,利尔达5G R16模组NE16U-CN 率先顺利通过了华为OpenLab的认证测试,成为首批基于展锐V516芯片平台通过华为认证测试的5G模组,实现了基于3GPP R16协议版本的业务验证。 这表明,利尔达NE16U-CN模组已支持3GPP R16所具有的5G LAN、…

Overview of Computer Graphics

ContentsWhat is Computer Graphics?Why study Computer Graphics?ApplicationsFundamental Intellectual ChallengesTechnical ChallengesCourse TopicsRasterization (光栅化)Curves and Meshes (曲线和曲面)Ray Tracing (光线追踪)Animation / Simulation (动画 / 模拟)Re…

ANACONDA的进阶理解和思考

0. 继续深入了解anaconda 0.1 Anaconda 是 Python 的一个开源发行版本 里面集成了很多关于 python 科学计算的第三方库,主要面向科学计算且安装方便,而 python 是一个编译器 如果不使用 anaconda,那么安装库的时候,库的依赖安装起…

力扣LeetCode算法题 第6题-Z 字形变换

要求: 一开始看到题目,第一想到的思路,就被题目要求的思路给带偏了。 内容是Z字型输出内容 就一直想着把字符串输出成上面这种格式 总是想着把字符串放入到二维数组中进行展示。 这样一来思路就受到了限制。 一直使用先写入数组中。 //将…

直播邀请函 | 第12届亚洲知识产权营商论坛:共建创新价值 开拓崭新领域

由香港特别行政区政府、香港贸易发展局及香港设计中心共同举办的亚洲知识产权营商论坛,每年为世界各地知识产权业界专家、商界领袖提供一个理想平台,共同探讨亚洲知识产权市场的最新发展,发掘更多商机。 去年,论坛共邀请70余位国…

使用HBuilder X开发Vue3+node+element-plus(一)

开发Vue3有很多的工具,比如VSCode,它也非常的好用,本文主要使用HBuilder X开发。 环境3个: Windows10 Node安装 1.打开官网,选择一个版本,进行安装 Node.js 2.选择路径,下一步就行了 3. 输…

【深度学习】torch.argmax()函数讲解 | pytorch

文章目录前言一、两个维度的张量使用torch.argmax()函数二、三个维度的张量使用torch.argmax()函数前言 这篇博客也是属于看了好久一直没写,终于写了。 一、两个维度的张量使用torch.argmax()函数 我们直接先举一个例子吧,我们随机生成一个2X3的张量&…

[附源码]SSM计算机毕业设计亿上汽车在线销售管理系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

11 Daemonset:忠实可靠的看门狗

文章目录1. 前言2. 为什么要有 DaemonSet(看门狗)?3. 如何使用 YAML 描述 DaemonSe?3.1 参考官网创建DaemonSet YAML3.1.1 DaemonSet YAML 和 Deployment YAML 文件对比3.1.2 DaemonSet YAML 和 Deployment YAML 文件对比图示3.2 用变通的方法来创建 DaemonSet 的 …

【Python模块】图形化编程模块-turtle

Turtle,也称海龟渲染器,是 Python 内置的图形化模块,它使用 tkinter 实现基本图形界面,因此 当前使用的 Python 环境需要支持 tkinter。 Turtle 提供了面向对象和面向过程两种形式的海龟绘图基本组件。使用它可以轻松的实现图形的…

初探Golang语法巩固复习

最近在家,重新拾起Go语言,搭建环境可参考之前博客【初探Golang语言之环境搭建】,本文是基本语法熟悉与练习,方便备查。 判断语句 if && switch if 通过指定一个或多个条件,并通过测试条件是否为true来决定是…

[附源码]计算机毕业设计springboot驾校预约管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

xcode swift 单元测试 test

XCTest是苹果官方的测试框架&#xff0c;是基于OCUnit的传统测试框架&#xff0c;测试编写起来非常简单。 测试案例一 创建一个单元测试 func testExample() throws {let personID:String "0123456789"let count personID.countXCTAssert(count < 10, "I…

aPaaS是什么(aPaaS平台和IPaaS的区别是啥?大白话解释)

依题&#xff1a;aPaaS是什么&#xff1f;aPaaS与iPaaS二者之间的区别在哪&#xff1f;要想了解区别&#xff0c;首先得搞清概念&#xff0c;不然就是在耍流氓&#xff01;下面本人就从概念到区别用大白话给你一次性讲清楚。 一、什么是aPaaS&#xff1f; 应用程序平台即服务&…

freeswitch配置SBC的方案

概述 freeswitch 是一款好用的开源软交换平台。 但是&#xff0c;fs不是专为SBC而开发的&#xff0c;所以需要做一些定制化的配置和开发。 本文主要介绍如何利用fs的基本功能配置一个简单的SBC方案&#xff0c;满足一般化需求&#xff0c;如果有定制化的需求需要定制开发。 …

QT简单串口通信终端实现

1.工程文件 工程文件中添加serilport QT serialport 2.主程序 主程序文件main.cpp #include "mainwindow.h" #include <QApplication> int main(int argc, char *argv[]) { QApplication a(argc, argv); MainWindow w; w.show(); return a.exec(); } …

xray长亭是自动化Web漏洞扫描神器

xray长亭一款完善的安全评估工具&#xff0c;支持常见 web 安全问题扫描和自定义 。 xray 是一款功能强大的安全评估工具&#xff0c;由多名经验丰富的一线安全从业者呕心打造而成&#xff0c;主要特性有: 检测速度快。发包速度快; 漏洞检测算法效率高。支持范围广。大至 OWAS…

Python error:Compressed file ended before the end-of-stream marker was reached

功能描述 在做http协议处理时&#xff0c;经常遇到gzip格式的数据需要进行还原解压缩处理。 解压缩用到的Python库为 import gzip 报错 unpack_gzip error:Compressed file ended before the end-of-stream marker was reached 压缩文件在到达流结束标记之前结束 原因 该…

上海亚商投顾:沪指创反弹新高 房地产板块掀涨停潮

上海亚商投顾前言&#xff1a;无惧大盘大跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 市场情绪三大股指今日窄幅震荡&#xff0c;最终尾盘小幅收红。房地产板块午后跳水&#xff0c;首开股份跌停&#xff0c;粤宏…