【机器学习】pytorch 常用函数解析

news2024/11/20 8:48:12

目录

一、基本函数介绍

1.1 nn.Module 类

1.2 nn.Embedding

1.3 nn.LSTM

1.4 nn.Linear

1.5 nn.CrossEntropyLoss

1.6 torch.save

1.7 torch.load

1.8 nn.functional

1.9 nn.functional.softmax


本文主要对 pytorch 中用到的函数进行介绍,本文会不断更新~

一、基本函数介绍

1.1 nn.Module 类

我们自己定义模型的时候,通常继承 nn.Module 类,然后重写 nn.Module 中的方法,nn.Module 的主要方法如下所示。

class Module(object):
    def __init__(self):
    def forward(self, *input):
 
    def add_module(self, name, module):
    def cuda(self, device=None):
    def cpu(self):
    def __call__(self, *input, **kwargs):
    def parameters(self, recurse=True):
    def named_parameters(self, prefix='', recurse=True):
    def children(self):
    def named_children(self):
    def modules(self):  
    def named_modules(self, memo=None, prefix=''):
    def train(self, mode=True):
    def eval(self):
    def zero_grad(self):
    def __repr__(self):
    def __dir__(self):
    #......还有一部分,此处未列出

自定义模型一般重写 __init__ 和 forward 函数。

1.2 nn.Embedding

nn.Embedding(num_embeddings, embedding_dim)

(1)参数

num_embeddings :嵌入字典的大小,也可以理解为该模型可以表示词的数量;

embedding_size:表示嵌入向量的维度。

nn.Embedding 层的本质是一个查找表,它将输入的每个索引映射到一个固定大小的向量,可以理解为每个词都有一个固定的向量。这个映射表在初始化时会随机生成,然后在训练过程中通过反向传播进行优化。

(2)主要步骤

初始化:在初始化时,nn.Embedding 会创建一个大小为 (num_embeddings, embedding_dim)的权重矩阵。这些权重是嵌入层的参数,会在训练过程中更新;

前向传播:在前向传播过程中,nn.Embedding 层会将输入的索引映射到权重矩阵的相应行,从而得到对应的嵌入向量;

反向传播:在训练过程中,嵌入层的权重矩阵会根据损失函数的梯度进行更新。这使得嵌入向量能够捕捉到输入的语义信息。

(3)nn.Embedding 原理

nn.Embedding 的核心是一个查找表,其大小为 (num_embeddings,embedding_dim),每一行代表一个词或索引的嵌入向量。 在向量化时,输入的索引被用来查找嵌入向量,假设输入是 [1, 2, 3],则输出是权重矩阵(num_embeddings,embedding_dim)中第 1、2、3 行的向量。

下面通过一个例子进行说明。

import torch
import torch.nn as nn

# 创建 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)

# 输入
input_indices = torch.LongTensor([1, 2, 3, 4])

# 转换为嵌入向量
output_vectors = embedding_layer(input_indices)

# 输出
print("input_indices:", input_indices)
print("output_vectors:", output_vectors)

输出如下所示。

(chat6b) D:\code\ChatGLM-6B-main>python test.py
input_indices: tensor([1, 2, 3, 4])
output_vectors: tensor([[-0.3269, -1.2620,  0.0695],
        [-1.6919, -1.6591, -0.7417],
        [ 2.0479,  0.9768,  1.4318],
        [-0.7075,  1.1718,  0.7530]], grad_fn=<EmbeddingBackward0>)

(chat6b) D:\code\ChatGLM-6B-main>

输出一共包含四个向量,每行表示一个。

1.3 nn.LSTM

后续更新~

1.4 nn.Linear

nn.Linear 是神经网络的线性层,可以看作是通过一个二维矩阵做了一个转换。

torch.nn.Linear(in_features,  # 输入的神经元个数
                out_features, # 输出神经元个数
                bias=True     # 是否包含偏置
                )

nn.Linear 对输入执行线性变换,如下所示。

其中,X 表示输入,Y 表示输出,b 为偏置。

下面来看一个例子。

import torch
from torch import nn


input = torch.Tensor([1, 2, 3]) # 样本有 3 个特征
model = nn.Linear(3, 2) # 输入特征数为 3,输出特征数为 2
print("model = ", model)
# nn.Linear 权重
for param in model.parameters():
    print(param)

output = model(input)
print(output)

输出如下所示。

model =  Linear(in_features=3, out_features=2, bias=True)
Parameter containing:
tensor([[-0.4270,  0.0396,  0.2899],
        [-0.4481,  0.4071,  0.4366]], requires_grad=True)
Parameter containing:
tensor([-0.1091,  0.3018], requires_grad=True)
tensor([0.4128, 1.9777], grad_fn=<ViewBackward0>)

 X(1x3)= [1,2,3], W(3x2) = [[-0.4270,  0.0396,  0.2899], [-0.4481,  0.4071,  0.4366]]的转置,b = [-0.1091,  0.3018],可以手动计算最后的结果,例如:0.4128 = -0.4270 * 1 + 0.0396 * 2 + 0.2899*3 - 0.1091,同理也可以计算 1.9777。

1.5 nn.CrossEntropyLoss

交叉熵(Cross-Entropy)是一种用于比较真实标签和预测标签概率之间差异的度量,交叉熵通常用作损失函数,用于衡量模型预测与真实标签之间的差异,尤其在分类任务中广泛使用。

交叉熵越小,模型预测越准确。当模型的预测与真实标签完全一致时,交叉熵达到最小值为 0。

import torch
import torch.nn as nn
from torch.nn.functional import one_hot

output = torch.randn(4, 3)  # 模型预测,4 个样本,3 分类
print('output:\n', output)

target = torch.tensor([1, 2, 0, 1])  # 真实标签值
target1 = target
# 实际上不需要转换为 one_hot,这里测试证明了这一点
target = one_hot(target, num_classes=3)
target = target.to(dtype=torch.float)
crossentropyloss = nn.CrossEntropyLoss()
output_loss = crossentropyloss(output, target)
output_loss1 = crossentropyloss(output, target1)

print('output_loss:\n', output_loss)
print('output_loss1:\n', output_loss1)

 顺便测试了下是否需要转换为 one_hat。

1.6 torch.save

torch.save() 的主要作用就是将 PyTorch 对象(如模型、张量等)保存到磁盘上,以文件的形式进行存储。如果想使用训练后的模型,从磁盘上加载即可。

torch.save(model,保存路径)  # 保存整个模型
torch.save(model.state_dict(), 保存路径) # 只保存模型参数

 CrossEntropyLoss() 损失函数结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。它在做分类训练的时候是非常有用的。

1.7 torch.load

 torch.load() 函数用于加载磁盘上模型文件。 

 torch.load(模型路径)

1.8 nn.functional

nn.functional 是 PyTorch 中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。

这些函数通常用于执行各种非线性操作、损失函数、激活函数等。 这个模块的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。

1.9 nn.functional.softmax

softmax 有两种形式。

torch.nn.Softmax(input, dim)
torch.nn.functional.softmax(input, dim)

下面主要对 torch.nn.functional.softmax 进行介绍。 

对 n 维输入张量运用 softmax 函数,将张量的每个元素缩放到(0,1)区间且和为1。

softmax(input, dim=None, _stacklevel=3, dtype=None)

主要参数:

input : 输入的张量;

dim : 指明维度,dim=0表示按列计算;dim=1表示按行计算。默认dim的方法已经弃用了,最好声明dim,否则会警告。

softmax 公式如下所示。

下面来看一个例子。

import torch
import torch.nn.functional as F
 
input = torch.Tensor([[1, 2, 3, 4],[1, 2, 3, 4]])
 
output1 = F.softmax(input, dim=0) #对每一列进行softmax
print(output1)
 
output2 = F.softmax(input, dim=1) #对每一行进行softmax
print(output2)

 输出如下所示。

tensor([[0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000]])
tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])

分别对输入张量的列和行进行了 softmax。 

后续更新:torch.randn、torch.tensor、one_hot、torch.LongTensor

参考链接:

[1] Pytorch nn.Linear()的基本用法与原理详解及全连接层简介_nn.linear()作用-CSDN博客

[2] pytorch教程之nn.Module类详解——使用Module类来自定义模型-CSDN博客

[3] torch.nn - PyTorch中文文档 

[4] pytorch nn.Embedding 用法和原理_pytorch nn.embedding 设置初始化函数-CSDN博客

[5] Pytorch nn.Linear()的基本用法与原理详解及全连接层简介_nn.linear()作用-CSDN博客 

[6] https://www.cnblogs.com/wanghui-garcia/p/10675588.html 

[7] PyTorch `nn.functional` 模块详解:探索神经网络的魔法工具箱_torch.nn.functional-CSDN博客  

[8] Pytorch CrossEntropyLoss() 原理和用法详解-CSDN博客 

[9] https://www.cnblogs.com/peixu/p/13194801.html 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1954873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Redis进阶】主从复制

1. 主从结构引入 在分布式系统中&#xff0c;涉及到一个严重问题&#xff1a;单点问题 即如果某个服务器程序只有一个节点&#xff08;单台机器提供服务&#xff09;&#xff0c;就会出现以下两个问题&#xff1a; 可用性问题&#xff0c;如果这台机器挂了&#xff0c;意味着…

Github 2024-07-27开源项目日报 Top10

根据Github Trendings的统计,今日(2024-07-27统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量非开发语言项目2C++项目2C项目2TypeScript项目1JavaScript项目1Java项目1Python项目1C#项目1免费编程学习平台:freeCodeCamp.org 创建周期:33…

jQuery入门(一)

一、JQuery介绍 - jQuery 是一个 JavaScript 库。 - 所谓的库&#xff0c;就是一个 JS 文件&#xff0c;里面封装了很多预定义的函数&#xff0c;比如获取元素&#xff0c;执行隐藏、移动等&#xff0c;目的就 是在使用时直接调用&#xff0c;不 需要再重复定义&#xff0c;这…

iPhone 在 App Store 中推出的 PC 模拟器 UTM SE

PC 模拟器是什么&#xff1f;PC 模拟器是一种软件工具&#xff0c;它模拟不同硬件或操作系统环境&#xff0c;使得用户可以在一台 PC 上运行其他平台的应用程序或操作系统。通过 PC 模拟器&#xff0c;用户可以在 Windows 电脑上体验 Android 应用、在 Mac 电脑上运行 Windows …

Python如何获取终端尺寸?

os.get_terminal_size()&#xff0c;无差别获取当前终端长宽&#xff0c;让你为所欲为。 (笔记模板由python脚本于2024年07月27日 08:30:53创建&#xff0c;本篇笔记适合喜欢钻研的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.org/ Fre…

使用命名管道的通信程序, 加入了日志系统

文章目录 日志系统通信程序运行效果 日志系统 // log.hpp #pragma once #include <time.h> #include <iostream> #include <stdio.h> #include <string> #include <stdarg.h> #include <sys/types.h> #include <sys/stat.h> #inclu…

软设之数据库关系代数

数据库关系代数 基本概念 元祖行&#xff1a;水平方向上每一行为一条记录&#xff0c;这个记录对应1个实体。一般称为元祖&#xff0c;元祖行或者记录 属性列&#xff1a;垂直方向上每一列为一个属性&#xff0c;一般称为属性列&#xff0c;字段等。关系表达式中可以用列序号…

又要起飞,浏览器居然都可以本地 OCR 啦

前言 PaddleOCR&#xff0c;这是一个由百度开发的开源 OCR&#xff08;Optical Character Recognition&#xff0c;光学字符识别&#xff09;工具&#xff0c;它可以用于从图像中识别文本。 PaddleOCR支持多种语言的文本识别&#xff0c;并且能够处理多种场景下的图像。 现在…

【Web开发手礼】探索Web开发的魅力(十二)-Vue(2)用户动态页面

前言 主要介绍了用vue框架创建用户动态页面的具体过程&#xff0c;可以帮助学习vue框架的基本知识&#xff01;&#xff01;&#xff01;&#xff01; 用户动态页面 用户信息 用户头像 通过 Bootstrap 所提供的 .img-fluid 类让图片支持响应式布局。其原理是将 max-width: 10…

Java面试八股之Spring boot的自动配置原理

Spring boot的自动配置原理 Spring Boot 的自动配置原理是其最吸引人的特性之一&#xff0c;它大大简化了基于 Spring 框架的应用程序开发。以下是 Spring Boot 自动配置的基本原理和工作流程&#xff1a; 1. 启动类上的注解 Spring Boot 应用通常会在主类上使用 SpringBoot…

ZBrush入门使用介绍——4、笔刷选项说明

大家好&#xff0c;我是阿赵。   这次来看看ZBrush的笔刷的选项用法。 一、选择笔刷 点击笔刷&#xff0c;可以打开笔刷选择面板。 在最上面的Quick Pick&#xff0c;有最近使用过的笔刷&#xff0c;可以快速的选择。下面有很多可以选择的笔刷。但由于笔刷太多&#xff0c;…

AJAX之基础知识

目录 AJAX入门及axios使用什么是AJAX怎么用AJAX 认识URL协议域名资源路径URL查询参数 查询参数URL查询参数axios查询参数 常用请求方法axios请求配置 axios错误处理HTTP协议请求报文请求报文-错误排查响应报文HTTP响应状态码 form-serialize插件 AJAX入门及axios使用 什么是AJ…

【Python机器学习】决策树的构造——信息增益

决策树是最经常使用的数据挖掘算法。它之所以如此流行&#xff0c;一个很重要的原因就是不需要了解机器学习的知识&#xff0c;就能搞明白决策树是如何工作的。 决策树的优缺点&#xff1a; 优点&#xff1a;计算复杂度不高&#xff0c;输出结果易于理解&#xff0c;对中间值的…

RabbitMq手动ack的超简单案例+Confirm和Return机制的配置和使用

最简单的例子 先简单介绍一下这三个方法 basicAck 表示确认成功&#xff0c;使用此方法后&#xff0c;消息会被rabbitmq broker删除 basicNack 表示失败确认&#xff0c;一般在消费消息业务异常时用到此方法&#xff0c;可以将消息重新投递入队列 basicReject 拒绝消息&am…

Chainlit一个快速构建成式AI应用的Python框架,无缝集成与多平台部署

概述 Chainlit 是一个开源 Python 包&#xff0c;用于构建和部署生成式 AI 应用的开源框架。它提供了一种简单的方法来创建交互式的用户界面&#xff0c;这些界面可以与 LLM&#xff08;大型语言模型&#xff09;驱动的应用程序进行通信。Chainlit 旨在帮助开发者快速构建基于…

全网最适合入门的面向对象编程教程:25 类和对象的 Python 实现-Python 判断输入数据类型

全网最适合入门的面向对象编程教程&#xff1a;25 类和对象的 Python 实现-Python 判断输入数据类型 摘要&#xff1a; 本文主要介绍了在使用 Python 面向对象编程时&#xff0c;如何使用 type 函数、isinstance 函数和正则表达式三种方法判断用户输入数据类型&#xff0c;并对…

PWA(渐进式网页应用)方式实现TodoList桌面应用

参考&#xff1a; https://cloud.tencent.com/developer/article/2322236 todlist网页参考&#xff1a; https://blog.csdn.net/weixin_42357472/article/details/140657576 实现在线网页当成app应用&#xff1a; 一个 PWA 应用首先是一个网页, 是通过 Web 技术编写出的一个网…

如何全面提升架构设计的质量?

当我们从可扩展、高可用、高性能等角度设计出来架构的时候&#xff0c;我们如何优化架构呢&#xff1f;就需要从成本、安全、测试等角度进行优化。 如何设计更好的架构 - 步骤 成本 低成本复杂度本质 低成本手段和应用 低成本的主要应用场景 安全 安全性复杂度本质 架构安全…

大语言模型系列-Transformer:深入探索与未来展望

大家好&#xff0c;我是一名测试开发工程师&#xff0c;已经开源一套【自动化测试框架】和【测试管理平台】&#xff0c;欢迎大家联系我&#xff0c;一起【分享测试知识&#xff0c;交流测试技术】 Transformer模型自其问世以来&#xff0c;便迅速在自然语言处理领域崭露头角&a…