Pytorch个人学习记录总结 09

news2024/11/25 7:53:45

目录

损失函数与反向传播

L1Loss

MSELOSS

CrossEntropyLoss


损失函数与反向传播

所需的Loss计算函数都在torch.nn的LossFunctions中,官方网址是:torch.nn — PyTorch 2.0 documentation。举例了L1Loss、MSELoss、CrossEntropyLoss。


在这些Loss函数的使用中,有以下注意的点:
(1) 参数reduction='mean',默认是'mean'表示对差值的和求均值,还可以是'sum'则不会求均值。
(2) 一定要注意Input和target的shape。 

L1Loss

创建一个标准,用于测量中每个元素之间的Input: x xx 和 target: y yy。

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的平均绝对误差(MAE)(L 1 L_1L 1范数)。

 

 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度的loss求均值
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。

假设 a aa 是标量,则有:

type(a) = torch.Tensor
a.shape = torch.Size([])
a.dim = 0
 

MSELOSS

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的均方误差(平方L2范数)。


 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度求loss
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。


CrossEntropyLoss

该标准计算 input 和 target 之间的交叉熵损失。

非常适用于当训练 C CC 类的分类问题(即多分类问题,若是二分类问题,可采用BCELoss)。如果要提供可选参数 w e i g h t weightweight ,那 w e i g h t weightweight 应设置为1维tensor去为每个类分配权重。这在训练集不平衡时特别有用。

期望的 input应包含每个类的原始的、未标准化的分数。input必须是大小为C CC(input未分批)、(m i n i b a t c h , C minibatch,Cminibatch,C) or (m i n i b a t c h , C , d 1 , d 2 , . . . d kminibatch,C,d_1,d_2,...d_kminibatch,C,d 1,d 2,...d k )的Tensor。最后一种方法适用于高维输入,例如计算2D图像的每像素交叉熵损失。

期望的 target应包含以下内容之一:

(1) (target包含了)在[ 0 , C ) [0,C)[0,C)区间的类别索引,C CC是类别总数量。如果指定了 ignore_index,则此损失也接受此类索引(此索引不一定在类别范围内)。reduction='none'情况下的loss为

注意:l o g loglog默认是以10为底的。
 

x是input,y yy是target,w ww是权重weight,C CC是类别数量,N NN涵盖minibatch维度且d 1 , d 2 . . . , d k d_1,d_2...,d_kd 1,d 2...,d k分别表示第k个维度。(N太难翻译了,总感觉没翻译对)如果reduction='mean'或'sum',

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):  # 模型前向传播
        return self.model(x)


model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数

for data in dataloader:
    imgs, targets = data
    outputs = model(imgs)
    # print(outputs)    # 先打印查看一下结果。outputs.shape=(2, 10) 即(N,C)
    # print(targets)    # target.shape=(2) 即(N)
    # 观察outputs和target的shape,然后选择使用哪个损失函数
    res_loss = loss_cross(outputs, targets)
    res_loss.backward()  # 损失反向传播
    print(res_loss)

#
# inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
# targets = torch.tensor([1, 2, 5], dtype=torch.float32)
#
# inputs = torch.reshape(inputs, (1, 1, 1, 3))
# targets = torch.reshape(targets, (1, 1, 1, 3))
#
# # -------------L1Loss--------------- #
# loss = nn.L1Loss()
# res = loss(inputs, targets)  # 返回的是一个标量,ndim=0
# print(res)  # tensor(1.6667)
#
# # -------------MSELoss--------------- #
# loss_mse = nn.MSELoss()
# res_mse = loss_mse(inputs, targets)
# print(res_mse)
#
# # -------------CrossEntropyLoss--------------- #
# x = torch.tensor([0.1, 0.2, 0.3])  # (N,C)
# x = torch.reshape(x, (1, 3))
# y = torch.tensor([1])  # (N)
# loss_cross = nn.CrossEntropyLoss()
# res_cross = loss_cross(x, y)
# print(res_cross)


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/794770.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Andrew算法求凸包模板

前置知识 向量的叉乘: 设 a ⃗ ( x a , y a , z a ) , b ⃗ ( x b , y b , z b ) \vec a(x_a,y_a,z_a), \vec b(x_b, y_b,z_b) a (xa​,ya​,za​),b (xb​,yb​,zb​), 令 a ⃗ \vec a a 和 b ⃗ \vec b b 的叉乘为 c ⃗ \vec c c , 有: c ⃗ ∣ i j k x a y a z a x b y…

微软对Visual Studio 17.7 Preview 4进行版本更新,新插件管理器亮相

近期微软发布了Visual Studio 17.7 Preview 4版本,而在这个版本当中,全新设计的扩展插件管理器将亮相,并且可以让用户可更简单地安装和管理扩展插件。 据了解,目前用户可以从 Visual Studio Marketplace 下载各式各样的 VS 扩展插…

[DASCTF 2023 0X401七月暑期挑战赛] viphouse复现

这个题想了好久,无果,终于看到WP。照着作了一遍。WP里没有详细解释,所以复现得很辛苦。 程序有5个菜单和1个初始化程序: init 先从os.random读8个字节放到src处login 读入用户名密码,其中密码栈里设的0x40可以读入0x…

【Ap模块EM】08-怎么让Execution Management成为第一个执行的进程?

前面的文章,我们讲述了ubuntu系统上电执行的流程,那么在Ap AutoSAR中Execution Management怎么成为第一个被执行的进程呢额?就是让它取代传统的init进程,成为ubuntu系统第一个执行的进程? 我们可以通过符号链接 symbolic link去实现,这个类似于windows系统中的某个exe文件…

mongoDB详解

mongoDB详解 mongodb是一个nosql数据库,它有高性能、无模式、文档型的特点。他是nosql数据库中功能最丰富,最像关系数据库的。 mongoDb基本介绍 mongodb里面有以下几个核心概念:文档:mongodb数据库的最小数据集,是由…

关于阿里云OSS服务器绑定域名及Https证书

这是一个没有套路的前端博主,热衷各种前端向的骚操作,经常想到哪就写到哪,如果有感兴趣的技术和前端效果可以留言~博主看到后会去代替大家踩坑的~ 主页: oliver尹的主页 格言: 跌倒了爬起来就好~ 关于阿里云…

安全学习DAY07_其他协议抓包技术

协议抓包技术-全局-APP&小程序&PC应用 抓包工具-Wireshark&科来分析&封包 TCPDump: 是可以将网络中传送的数据包完全截获下来提供分析。它支持针对网络层、协议、主机、网络或端口的过滤,并提供and、or、not等逻辑语句来帮助你去掉无用…

TypeError: Failed to fetch dynamically imported module

浏览器报了如下错误: vue文件如下: 错误出现的原因是因为导入的是路径,vue会在该路径下的文件夹搜索所有文件,但是没有找到对应的组件,但是浏览器并不会直接禁止访问,而是在控制台报错,解决办法…

量化交易——python数据分析及可视化

该项目分为两个部分:一是数据计算,二是可视化,三是MACD策略 一、计算MACD 1、数据部分 数据来源:tushare 数据字段包含:日期,开盘价,收盘价,最低价,最高价&#xff0c…

11.python设计模式【责任链模式】

内容:使多个对象都有机会处理请求,从而避免请求的发送者和接收者之间的耦合关系。将这些对象连成一条链,并沿着这条链传递该请求,直到有一个对象处理它为止。角色: 抽象处理者(Handler)具体处理…

Unity实现camera数据注入RMP推送或轻量级RTSP服务模块

技术背景 随着技术的不断进步和应用的不断深化,Unity3D VR应用的前景非常广阔,它广泛应用于教育、医疗、军事、工业设计、虚拟数字人等多个领域。 教育领域:Unity3D VR技术可以用来创建虚拟现实教室,让学生能够身临其境地体验课…

微信小程序 事件和语法

属性列表 target&#xff1a;触发事件的对象currentTarget&#xff1a;绑定事件的对象&#xff0c;可能不是由此对象触发而是触发对象冒泡导致触发此事件 微信小程序 事件绑定 <button type"primary" bindtap"btnTapHandler">按钮</button> …

pytorch学习-线性神经网络——softmax回归+损失函数+图片分类数据集

1.softmax回归 Softmax回归&#xff08;Softmax Regression&#xff09;是一种常见的多分类模型&#xff0c;可以用于将输入变量映射到多个类别的概率分布中。softmax回归是机器学习中非常重要并且经典的模型&#xff0c;虽然叫回归&#xff0c;实际上是一个分类问题 1.1分类与…

ORCA优化器浅析——Exploration and Implementation Apply CXform Phase

GPORCA通过规则对表达式做转换&#xff0c;规则分为两类&#xff1a;Exploration&#xff0c;是对逻辑表达式做等价变换的&#xff1b;Implementation&#xff0c;是把逻辑操作符转换为物理操作符。Except the types of operator generated during Exploration (Logical > L…

手机python怎么用海龟画图,python怎么在手机上编程

大家好&#xff0c;给大家分享一下手机python怎么用海龟画图&#xff0c;很多人还不知道这一点。下面详细解释一下。现在让我们来看看&#xff01; 1、如何python手机版创造Al&#xff1f; 如果您想在手机上使用Python来创建AI&#xff08;人工智能&#xff09;程序&#xff0…

服务器推送、在线游戏和电子邮件背后的网络协议

之前也聊了不少网络协议这块内容&#xff0c;现在我们将深入探讨关键的网络协议及其在不同应用中的作用。重点在于理解这些协议如何塑造我们在互联网上的通信和互动方式。我们将深入研究以下领域&#xff1a; WebSocket 在之前的讨论中&#xff0c;我们研究了HTTP及其在客户端和…

帮助中心客户服务中的作用与应用

随着互联网的快速发展&#xff0c;越来越多的企业开始寻求更加高效、便捷的客户服务方式&#xff0c;帮助中心的自助服务作为一种新型的客户服务模式&#xff0c;受到了广泛关注。本文将针对帮助中心的自助服务在客户服务中的作用与应用进行深入探讨。 帮助中心在客户服务中的…

scrcpy2.0+实时将手机画面显示在屏幕上并用鼠标模拟点击2023.7.26

想要用AI代打手游&#xff0c;除了模拟器登录&#xff0c;也可以直接使用第三方工具Scrcpy&#xff0c;来自github&#xff0c;它是一个开源的屏幕镜像工具&#xff0c;可以在电脑上显示Android设备的画面&#xff0c;并支持使用鼠标进行交互。 目录 1. 下载安装2. scrcpy的高级…

文本预处理——文本数据分析

目录 文本数据分析中文酒店评价语料获得训练集和验证集的标签数量分布获取训练集和验证集的句子长度分布获取训练集和验证集的正负样本长度散点分布获得训练集和验证集不同词汇总数统计获得训练集上正负的样本的高频形容词词云获得验证集上正负的样本的形容词词云 文本数据分析…

Golang指针详解

要搞明白Go语言中的指针需要先知道3个概念&#xff1a;指针地址、指针类型和指针取值。 指针介绍 我们知道变量是用来存储数据的&#xff0c;变量的本质是给存储数据的内存地址起了一个好记的别名。比如我们定义了一个变量 a : 10 ,这个时候可以直接通过 a 这个变量来读取内存…