pytorch常用的模块函数汇总(1)

news2024/11/24 16:36:06

目录

torch:核心库,包含张量操作、数学函数等基本功能

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

torch.autograd:自动求导模块,用于计算张量的梯度


torch:核心库,包含张量操作、数学函数等基本功能

  1. torch.tensor(): 创建张量
  2. torch.zeros()torch.ones(): 创建全零或全一张量
  3. torch.rand(): 创建随机张量
  4. torch.from_numpy(): 从 NumPy 数组创建张量
  5. torch.add()torch.sub()torch.mul()torch.div(): 加法、减法、乘法、除法

  6. torch.mm()torch.matmul(): 矩阵乘法

  7. torch.exp()torch.log()torch.sin()torch.cos(): 指数、对数、正弦、余弦等数学函数

  8. torch.index_select(): 按索引选取张量的子集

  9. torch.masked_select(): 根据掩码选取张量的子集

    切片操作:类似 Python 中的列表切片操作,如 tensor[2:5]
  10. torch.view()torch.reshape(): 改变张量的形状

  11. torch.squeeze()torch.unsqueeze(): 压缩或扩展张量的维度

  12. torch.mean()torch.sum()torch.max()torch.min(): 计算张量均值、和、最大值、最小值等

  13. torch.broadcast_tensors(): 对张量进行广播操作

  14. torch.cat(): 拼接张量

  15. torch.stack(): 堆叠张量

  16. torch.split(): 分割张量

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

  • 神经网络层

    • torch.nn.Linear(in_features, out_features): 全连接层,进行线性变换。
    • torch.nn.Conv2d(in_channels, out_channels, kernel_size): 2D卷积层。
    • torch.nn.MaxPool2d(kernel_size): 2D 最大池化层。
    • torch.nn.ReLU(): ReLU 激活函数。
    • torch.nn.Sigmoid(): Sigmoid 激活函数。
    • torch.nn.Dropout(p): Dropout 层,用于防止过拟合。

备注:Sigmoid 激活函数是一种常用的非线性激活函数,其作用可以总结如下:

将输入映射到 (0, 1) 范围内:输出范围在 0 到 1 之间,可以将任意实数输入映射到 0 到 1 之间。这种特性在某些情况下很有用,比如对于二分类任务,Sigmoid 函数的输出可以被解释为样本属于正类的概率。

引入非线性变换: Sigmoid 函数是一种非线性函数,可以引入神经网络的非线性变换能力,使得神经网络可以学习更加复杂的模式和关系。在深度神经网络中,非线性激活函数的使用可以帮助神经网络学习非线性模式,提高网络的表达能力。

输出平滑且连续: Sigmoid 函数具有平滑的 S 形曲线,在定义域内都是可导的,这使得在反向传播算法中计算梯度变得相对容易。这一点对于神经网络的训练至关重要。

  • 损失函数

torch.nn.CrossEntropyLoss(): 交叉熵损失函数,常用于多分类问题。

交叉熵损失函数用于衡量两个概率分布之间的差异,通常用于多分类任务中。在神经网络的多分类任务中,输入模型的输出是一个概率分布,表示每个类别的预测概率,而交叉熵损失函数则用于比较这个预测概率分布与实际标签的分布之间的差异。

torch.nn.CrossEntropyLoss() 来计算交叉熵损失函数,它会自动将模型的输出通过 Softmax 函数转换为概率分布,并计算交叉熵损失。

torch.nn.MSELoss(): 均方误差损失函数,常用于回归问题。

均方误差损失函数用于衡量模型输出与实际目标之间的差异,通常在回归任务中使用。该损失函数计算预测值与真实值之间的平方差,并将所有样本的平方差求平均得到最终的损失值。

  • 优化器

    • torch.optim.SGD(model.parameters(), lr=learning_rate): 随机梯度下降优化器。
    • torch.optim.Adam(model.parameters(), lr=learning_rate): Adam 优化器。
  • 模型定义相关

    • torch.nn.Module: 所有神经网络模型的基类,需要继承这个类。
    • model.forward(input_tensor): 定义前向传播。
  • 数据处理相关

    • torch.utils.data.Dataset: PyTorch 数据集的基类,需要自定义数据集时使用。
    • torch.utils.data.DataLoader(dataset, batch_size, shuffle): 数据加载器,用于批量加载数据。
  • torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

  • 优化器(Optimizer)类

    • torch.optim.SGD(params, lr=0.01, momentum=0, weight_decay=0):随机梯度下降优化器,实现了带动量的随机梯度下降
    • torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):Adam 优化器,结合了动量方法和 RMSProp 方法,通常在深度学习中表现良好。
    • torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0):Adagrad 优化器,自适应地为参数分配学习率。它根据参数的历史梯度信息对每个参数的学习率进行调整。这意味着对于不同的参数,Adagrad可以为其分配不同的学习率,从而更好地适应参数的更新需求
    • torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):RMSprop 优化器,有效地解决了 Adagrad 学习率下降较快的问题(RMSprop对梯度平方项进行指数加权平均)。

备注:

1.在优化算法中,momentum(动量)是一种用于加速模型训练的技巧。动量项的引入旨在解决标准随机梯度下降在训练过程中可能遇到的震荡和收敛速度慢的问题。

动量项的引入可以帮助优化算法在参数更新时更好地利用之前的更新方向,从而在一定程度上减少参数更新的波动,加快收敛速度,并有助于跳出局部极小值。具体来说,动量项在参数更新时会考虑之前的更新方向,并对当前的更新方向进行一定程度的调整。

在 PyTorch 的 torch.optim.SGD 中,动量可以通过设置 momentum 参数来控制。通常情况下,动量的取值范围在 0 到 1 之间,常见的默认取值为 0.9。当动量设为0.9时,每次迭代,都会保留上一次速度的 90%,并使用当前梯度微调最终的更新方向。

总结来说,动量项的引入可以提高随机梯度下降的稳定性和收敛速度,有助于在训练神经网络时更快地找到较优的解。

2. 

在优化算法中,weight decay(权重衰减)是一种用于控制模型参数更新的正则化技术。权重衰减通过在优化过程中对参数进行惩罚,防止其取值过大,从而有助于降低过拟合的风险。

具体来说,在SGD中的weight_decay参数是对模型的权重进行L2正则化,即在计算梯度时额外增加一个关于参数的惩罚项。这个惩罚项会使得优化算法更倾向于选择较小的权重值,从而降低模型的复杂度,减少过拟合的风险。

在PyTorch中,torch.optim.SGD中的weight_decay参数用于控制权重衰减的程度。通常情况下,weight_decay的取值为一个小的正数,比如 0.001 或 0.0001。设置了weight_decay之后,在计算梯度时会额外考虑到权重的惩罚项,从而影响参数的更新方式。

总结来说,权重衰减是一种正则化技术,通过对模型参数的惩罚来控制模型的复杂度,减少过拟合的风险,提高模型的泛化能力。

3.

  • L1正则化:L1正则化会给模型的损失函数添加一个关于权重绝对值的惩罚项,即L1范数(权重的绝对值之和)。在梯度下降过程中,L1正则化会导致部分权重直接变为0,因此可以实现稀疏性,有特征选择的效果。L1正则化倾向于产生稀疏的权重矩阵,可以用于特征选择和降维。

L1正则化的惩罚项是模型权重的L1范数,即权重的绝对值之和。在优化过程中,为了最小化损失函数并减少正则化项的影响,优化算法会尝试将权重调整到较小的值。由于L1正则化的几何形状在坐标轴上拐角处就会与坐标轴相交,这就导致了在坐标轴上许多点都是对称的,因此在这些点上的梯度不唯一。这意味着在这些对称点上,优化算法更有可能将权重调整为0,从而导致稀疏性。

  • L2正则化:L2正则化会给模型的损失函数添加一个关于权重平方的惩罚项,即L2范数的平方(权重的平方和)。在梯度下降过程中,L2正则化会使得权重都变得比较小,但不会直接导致稀疏性。L2正则化对异常值比较敏感,因为它会平方每个权重,使得异常值对损失函数的影响更大。

  • 总的来说,L1正则化和L2正则化都是常用的正则化技术,它们在模型训练过程中都有助于控制模型的复杂度,减少过拟合的风险。选择使用哪种正则化方法通常取决于具体的问题和数据特点,以及对模型稀疏性的需求。在实际应用中,有时也会将L1和L2正则化结合起来,形成弹性网络正则化(Elastic Net regularization),以兼顾两种正则化方法的优势。

4. 

betas 是 Adam 算法中的两个超参数之一,它控制了梯度的一阶矩估计和二阶矩估计的指数衰减率。betas 是一个长度为2的元组,通常形式为 (beta1, beta2)。在 Adam 算法中,beta1 控制了一阶矩估计(梯度的均值)的衰减率,beta2 控制了二阶矩估计(梯度的平方的均值)的衰减率。

通常情况下,beta1 的默认值为 0.9,beta2 的默认值为 0.999。这意味着在每次迭代中,一阶矩估计将保留当前梯度的 90%,而二阶矩估计将保留当前梯度的平方的 99.9%。这些衰减率的选择使得 Adam 算法能够在训练过程中自适应地调整学习率,并对梯度的变化做出快速或缓慢的响应,从而更有效地更新模型参数。

总之,betas 参数在 Adam 算法中起着调节梯度一阶和二阶矩估计衰减率的作用,通过合理设置 betas 可以影响算法的收敛性和稳定性。

  • 调整学习率的函数

    • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)根据给定的函数 lr_lambda 调整学习率。
    • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1):每个 step_size 个 epoch 将学习率降低为原来的 gamma 倍。
    • torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1):在指定的里程碑上将学习率降低为原来的 gamma 倍。
  • 其他常用函数

    • zero_grad():用于将模型参数的梯度清零,通常在每个 batch 后调用。
    • step(closure):用于执行单步优化器的更新,需要传入一个闭包函数 closure
    • state_dict() 和 load_state_dict():用于保存和加载优化器的状态字典,方便恢复训练。
  • torch.autograd:自动求导模块,用于计算张量的梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot的“校园台球厅人员与设备管理系统”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“校园台球厅人员与设备管理系统”的设计与实现(源码数据库文档PPT) 开发语言:Java 数据库:MySQL 技术:SpringBoot 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统功能结构图 系统首页界面图…

信息工程大学第五届超越杯程序设计竞赛(同步赛)题解

比赛传送门 博客园传送门 c 模板框架 #pragma GCC optimize(3,"Ofast","inline") #include<bits/stdc.h> #define rep(i,a,b) for (int ia;i<b;i) #define per(i,a,b) for (int ia;i>b;--i) #define se second #define fi first #define e…

YOLOv8算法改进【NO.111】利用shift-wise conv对顶会提出EMO中的iRMB进行二次创新

前 言 YOLO算法改进系列出到这&#xff0c;很多朋友问改进如何选择是最佳的&#xff0c;下面我就根据个人多年的写作发文章以及指导发文章的经验来看&#xff0c;按照优先顺序进行排序讲解YOLO算法改进方法的顺序选择。具体有需求的同学可以私信我沟通&#xff1a; 首推…

Vue + .NetCore前后端分离,不一样的快速发开框架

摘要&#xff1a; 随着前端技术的快速发展&#xff0c;Vue.NetCore框架已成为前后端分离开发中的热门选择。本文将深入探讨Vue.NetCore前后端分离的快速开发框架&#xff0c;以及它如何助力开发人员提高效率、降低开发复杂度。文章将从基础功能、核心优势、适用范围、依赖环境等…

Qt for WebAssembly 环境搭建 - Windows新手入门

Qt for WebAssembly 环境搭建 - Windows新手入门 一、所需工具软件1、安装Python2、安装Git2.1 注册Github账号2.2 下载安装Git2.2.1配置Git&#xff1a;2.2.2 配置Git环境2.2.3解决gitgithub.com: Permission denied (publickey) 3 安装em编译器 二、Qt配置编译器三、参考链接…

wireshark 使用

wireshark介绍 wireshak可以抓取经过主机网卡的所有数据包&#xff08;包括虚拟机使用的虚拟网卡的数据包&#xff09;。 环境安装 安装wireshark: https://blog.csdn.net/Eoning/article/details/132141665 安装网络助手工具&#xff1a;https://soft.3dmgame.com/down/213…

【Java SE】深入理解static关键字

&#x1f970;&#x1f970;&#x1f970;来都来了&#xff0c;不妨点个关注叭&#xff01; &#x1f449;博客主页&#xff1a;欢迎各位大佬!&#x1f448; 文章目录 1.static关键字1.1 static的概念1.2 static的作用1.3 static的用法1.3.1 static修饰成员变量1.3.2 static修饰…

记一次由gzip引起的nginx转发事故

故事背景 书接前几篇文章&#xff0c;仍然是交付甲方遇到的一个特殊诉求&#xff0c;从而引发了本期的事故。甲方的诉求是前端的请求过来&#xff0c;需要加密&#xff0c;但是要经过waf&#xff0c;必须要求是请求明文&#xff0c;那就要在waf和nginx之间做一个解密前置应用处…

【3】3道链表力扣题:删除链表中的节点、反转链表、判断一个链表是否有环

3道链表力扣题 一、删除链表中的节点&#x1f30f; 题目链接&#x1f4d5; 示例&#x1f340; 分析&#x1f4bb; 代码 二、反转链表&#x1f30f; 题目链接&#x1f4d5; 示例&#x1f340; 分析① 递归② 迭代 三、判断一个链表是否有环&#x1f30f; 题目链接&#x1f4d5; …

学习transformer模型-Input Embedding 嵌入层的简明介绍

今天介绍transformer模型的Input Embedding 嵌入层。 背景 嵌入层的目标是使模型能够更多地了解单词、标记或其他输入之间的关系。 从头开始嵌入Embeddings from Scratch 嵌入序列需要分词器tokenizer、词汇表和索引&#xff0c;以及词汇表中每个单词的三维嵌入。Embedding a s…

git clone 后如何 checkout 到 remote branch

what/why 通常情况使用git clone github_repository_address下载下来的仓库使用git branch查看当前所有分支时只能看到master分支&#xff0c;但是想要切换到其他分支进行工作怎么办❓ 其实使用git clone下载的repository没那么简单&#x1f625;&#xff0c;clone得到的是仓库…

一个 hipsolver 特征值示例

1&#xff0c;原理 通过雅可比旋转&#xff0c;对对称矩阵计算特征值和特征向量&#xff1b; 通过初等正交变换&#xff0c;每次把其中一个非主对角元素消成零&#xff0c;最终只剩主对角线非零元素为特征值&#xff0c;同时把初等变换累积下来&#xff0c;构成特征向量。 2&a…

使用1panel部署Ollama WebUI(dcoekr版)浅谈

文章目录 说明配置镜像加速Ollama WebUI容器部署Ollama WebUI使用问题解决&#xff1a;访问页面空白 说明 1Panel简化了docker的部署&#xff0c;提供了可视化的操作&#xff0c;但是我在尝试创建Ollama WebUI容器时&#xff0c;遇到了从github拉取镜像网速很慢的问题&#xf…

Java反射系列(3):从spring反射工具ReflectionUtils说起

传送门 在比较早的时候&#xff0c;就讨论过java反射的一些用法及概念&#xff1a; Java反射系列(1)&#xff1a;入门基础 以及反射的基石Class对象&#xff01; Java反射系列(2)&#xff1a;从Class获取父类方法说起 今天就从工作中实际的例子来看看反射的应用。 兼容…

pytest--python的一种测试框架--request请求加入headers

一、request headers中的cookie和session机制的作用与区别 Cookie 和 Session 是两种在客户端和服务器之间保持状态的技术。HTTP 协议本身是无状态的&#xff0c;这意味着服务器无法从上一次的请求中保留任何信息到下一次请求。Cookie 和 Session 机制就是为了解决这个问题。 …

240330-大模型资源-使用教程-部署方式-部分笔记

A. 大模型资源 Models - Hugging FaceHF-Mirror - Huggingface 镜像站模型库首页 魔搭社区 B. 使用教程 HuggingFace HuggingFace 10分钟快速入门&#xff08;一&#xff09;&#xff0c;利用Transformers&#xff0c;Pipeline探索AI。_哔哩哔哩_bilibiliHuggingFace快速入…

遥感数字图像处理的学习笔记

相关链接&#xff1a; 遥感数字图像处理实验教程&#xff08;韦玉春&#xff09;--部分实验问题回答 目录 1.什么是图像&#xff0c;什么是数字图像&#xff1f; 2.什么是遥感数字图像&#xff1f;模拟图像(照片)与遥感数字图像有什么区别&#xff1f; 3.什么是遥感数字图像…

小小狠招:巧妙使用HANA数据库的jdbc driver

SAP旗下的HANA数据库&#xff0c;实际上是分为两个系列进行发布&#xff0c;一种是基于本地部署的称之为HANA Platform。另一种是面向Cloud平台的&#xff0c;称之为HANA Cloud。 在实际使用当用&#xff0c;因为两者基本上共用同一代码库&#xff0c;除个别地方略有差异以外&…

C++要学到什么程度才能找到实习?

在考虑 C 学习到何种程度可以找到实习时&#xff0c;以下是一些具体的方向和建议。我这里有一套编程入门教程&#xff0c;不仅包含了详细的视频讲解&#xff0c;项目实战。如果你渴望学习编程&#xff0c;不妨点个关注&#xff0c;给个评论222&#xff0c;私信22&#xff0c;我…

HarmonyOS 应用开发之模型切换

本文介绍如何将一个FA模型开发的声明式范式应用切换到Stage模型&#xff0c;您需要完成如下动作&#xff1a; 工程切换&#xff1a;新建一个Stage模型的应用工程。 配置文件切换&#xff1a;config.json切换为app.json5和module.json5。 组件切换&#xff1a;PageAbility/Serv…