【PyTorch 攻略 (3/7)】线性组件、激活函数

news2025/1/24 8:45:04

一、说明

        神经网络是由层连接的神经元的集合。每个神经元都是一个小型计算单元,执行简单的计算来共同解决问题。它们按图层组织。有三种类型的层:输入层、隐藏层和输出层。每层包含许多神经元,但输入层除外。神经网络模仿人脑处理信息的方式。

二、神经网络的组件

  • 激活功能确定是否应该激活神经元。神经网络中发生的计算包括应用激活函数。如果一个神经元被激活,那么这意味着输入很重要。有不同种类的激活函数。选择使用哪个激活函数取决于您希望输出的内容。激活函数的另一个重要作用是向模型添加非线性。
    二进制用于将输出节点设置为 1(如果函数结果为正)和 0(如果函数结果为负)。
    Sigmoid 用于预测输出节点介于 0 和 1 之间的概率。
    Tanh 用于预测输出节点是否介于 1 和 -1 之间。用于分类用例。
    ReLU 用于在函数结果为负时将输出节点设置为 0,如果结果为正值,则保留结果值。
  • 权重会影响我们网络的输出接近预期输出值的程度。当输入进入神经元时,它被乘以权重值,结果输出被观察或传递到神经网络中的下一层。层中所有神经元的权重被组织成一个张量。
  • 偏差弥补了激活函数输出与其预期输出之间的差异。低偏差表明网络对输出形式做出更多的假设,而高偏差对输出形式做出的假设较少。

我们可以说,权重为 W 和偏差 b 的神经网络层的输出 y 计算为输入的总和乘以权重加上偏差。
x = ∑(权重∗输入)+ 偏置,其中 f(x) 是激活函数。

三、构建神经网络

        神经网络由对数据执行操作的层/模块组成。torch.nn 命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch 中的每个模块都对 nn 进行子类化。模块。神经网络本身是一个模块,由其他模块(层)组成。这种嵌套结构允许轻松构建和管理复杂的架构。

        在这里,我们将构建一个神经网络来对 FashionMNIST 数据集中的图像进行分类。

%matplotlib inline
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

四、定义类

        我们通过子类化 nn.Module 来定义神经网络,并在 __init__ 中启动神经网络层。每个 nn.Module 子类都在转发方法中实现对输入数据的操作。

        我们的神经网络由以下内容组成:
- 具有 28x28 或 784 个特征/像素的输入层。
- 第一个线性模块接受输入的 1 个特征并将其转换为具有 784 个特征的隐藏层。
- ReLU 激活函数将应用于转换。
- 第二个线性模块从第一个隐藏层获取 512 个特征作为输入,并将其转换为具有 2 个特征的下一个隐藏层。
- ReLU 激活函数将应用于转换。
- 第 512 个线性模块从第 1 个隐藏层获取 512 个特征作为输入,并将其转换为具有 3(类数)的输出层。
- ReLU 激活函数将应用于转换。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

我们创建一个NeuralNetwork的实例,并将其移动到设备中,并打印其结构。

model = NeuralNetwork().to(device)
print(model)
NeuralNetwork(
  (flatten): Flatten()
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)

为了使用该模型,我们将输入数据传递给它。这将执行模型的前函数以及一些后台操作。但是,不要直接调用 model.forward()!在输入上调用模型会返回一个 10 维张量,其中包含每个类的原始预测值。我们通过 nn 的实例传递它来获得预测密度。软最大

X = torch.rand(1, 28, 28, device=device)
logits = model(X) 
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
Predicted class: tensor([2], device='cuda:0')

让我们分解一下FashionMNIST模型中的层。为了说明它,我们将取一个包含 3 张大小为 28x28 的图像的示例小批量,看看当我们通过网络传递它时会发生什么。

input_image = torch.rand(3,28,28)
print(input_image.size())
torch.Size([3, 28, 28])

4.1 nn.flatten

        我们初始化 nn。拼合图层以将每个 2D 28x28 图像转换为 784 像素值的连续数组(保持小批量尺寸(在 dim=0 时))。每个像素都传递到神经网络的输入层。

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())
torch.Size([3, 784])

4.2 nn.linear

        线性层是一个模块,它使用其存储的权重和偏差对输入应用线性变换。输入层中每个像素的灰度值将连接到隐藏层中的神经元进行变换计算,即权重*输入+偏差

layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())
torch.Size([3, 20])

4.3 nn.relu

        非线性激活是在模型的输入和输出之间创建复杂映射的原因。它们被应用在线性变换之后引入非线性,帮助神经网络学习各种各样的现象。在这个模型中,我们使用 nn。线性层之间的 ReLU,但还有其他激活会在模型中引入非线性。

        ReLU 激活函数从线性层获取输出,并将负值替换为零。

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
Before ReLU: tensor([[ 0.2190,  0.1448, -0.5783,  0.1782, -0.4481, -0.2782, -0.5680,  0.1347,
          0.1092, -0.7941, -0.2273, -0.4437,  0.0661,  0.2095,  0.1291, -0.4690,
          0.0358,  0.3173, -0.0259, -0.4028],
        [-0.3531,  0.2385, -0.3172, -0.4717, -0.0382, -0.2066, -0.3859,  0.2607,
          0.3626, -0.4838, -0.2132, -0.7623, -0.2285,  0.2409, -0.2195, -0.4452,
         -0.0609,  0.4035, -0.4889, -0.4500],
        [-0.3651, -0.1240, -0.3222, -0.1072, -0.0112, -0.0397, -0.4105, -0.0233,
         -0.0342, -0.5680, -0.4816, -0.8085, -0.3945, -0.0472,  0.0247, -0.3605,
         -0.0347,  0.1192, -0.2763,  0.1447]], grad_fn=<AddmmBackward>)


After ReLU: tensor([[0.2190, 0.1448, 0.0000, 0.1782, 0.0000, 0.0000, 0.0000, 0.1347, 0.1092,
         0.0000, 0.0000, 0.0000, 0.0661, 0.2095, 0.1291, 0.0000, 0.0358, 0.3173,
         0.0000, 0.0000],
        [0.0000, 0.2385, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2607, 0.3626,
         0.0000, 0.0000, 0.0000, 0.0000, 0.2409, 0.0000, 0.0000, 0.0000, 0.4035,
         0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0247, 0.0000, 0.0000, 0.1192,
         0.0000, 0.1447]], grad_fn=<ReluBackward0>)

4.4 nn.Sequential

        nnSequential是模块的有序容器。数据按照定义的相同顺序传递到所有模块。您可以使用顺序容器将快速网络(如seq_modules)组合在一起。

seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)

4.5 nn.software

        神经网络的最后一个线性层返回对数 — [-infty, infty] 中的
        原始值,这些值被传递给 nn。软最大模块。Softmax 激活函数用于计算神经网络输出的概率。它仅用于神经网络的输出层。结果缩放为表示模型对每个类的预测密度的值 [0,1]。dim 参数指示结果值总和必须为 1 的维度。概率最高的节点预测所需的输出。

softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)

五、模型参数

        神经网络中的许多层都是参数化的,即具有相关的权重和偏差,这些权重和偏差在训练期间进行了优化。子类化 nn.模块会自动跟踪模型对象中定义的所有字段,并使用模型的 parameters() 或 named_parameter() 方法访问所有参数。

print("Model structure: ", model, "\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

下一>> PyTorch 简介 (4/7)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1019335.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虹科分享 | 谷歌Vertex AI平台使用Redis搭建大语言模型

文章来源&#xff1a;虹科云科技 点此阅读原文 基础模型和高性能数据层这两个基本组件始终是创建高效、可扩展语言模型应用的关键&#xff0c;利用Redis搭建大语言模型&#xff0c;能够实现高效可扩展的语义搜索、检索增强生成、LLM 缓存机制、LLM记忆和持久化。有Redis加持的大…

Docker启动Mysql容器并进行目录挂载

一、创建挂载目录 mkdir -p 当前层级下创建 mkdir -p mysql/data mkdir -p mysql/conf 进入到conf目录下创建配置文件touch hym.conf 并把配置文件hmy.conf下增加以下内容使用vim hym.conf即可添加(cv进去就行) Esc :wq 保存 [mysqld] skip-name-resolve character_set_…

设备树叠加层

设备树覆盖 设备树 (DT)是描述不可发现硬件的命名节点和属性的数据结构。内核&#xff08;例如 Android 中使用的 Linux 内核&#xff09;使用 DT 来支持 Android 设备使用的各种硬件配置。硬件供应商提供他们自己的设备树源 (DTS)文件&#xff0c;这些文件使用设备树编译器编…

UINT64整型数据在格式化时使用了不匹配的格式化符%d导致其他参数无法打印的问题排查

目录 1、问题描述 2、格式化函数内部解析待格式化参数的完整机制说明 2.1、传递给被调用函数的参数是通过栈传递的 2.2、格式化函数是如何从栈上找到待格式化的参数值&#xff0c;并完成格式化的&#xff1f; 2.3、字符串格式化符%s对应的异常问题场景说明 2.4、为了方便…

node 之 express 框架(初级)

一、express 热更新 1、安装扩展 npm install node-dev -D2、在根目录下的 package.json 文件中进行配置 3、之后的启动执行下面的命令即可 npm run dev二、mvc中的 模板引擎 1、ejs模板引擎的安装 npm install ejs -s2、在根目录下的app.js文件中配置 app.set(view engin…

我学编程全靠B站了,真香(第一期)

你好&#xff0c;我是Martin。 我是就读于B站大学2020届的Martin同学&#xff0c;反正我学习计算机真的是全靠 B 站了。 我是个刷视频狂魔&#xff0c;B站收藏夹里也收藏了很多编程类视频&#xff0c; 比如C/C、Go语言、操作系统、数据结构和算法、计算机网络、数据库、Pyth…

深入了解Python运算符和表达式:从基础到高级

&#x1f482; 个人网站:【工具大全】【游戏大全】【神级源码资源网】&#x1f91f; 前端学习课程&#xff1a;&#x1f449;【28个案例趣学前端】【400个JS面试题】&#x1f485; 寻找学习交流、摸鱼划水的小伙伴&#xff0c;请点击【摸鱼学习交流群】 Python运算符和表达式是…

JavaScript 学习笔记(基础)

其是一门跨平台、面向对象的脚本语言&#xff08;直译型语言&#xff09;&#xff0c;用来控制网页行为&#xff0c;能使网页产生交互效果&#xff01;下面以 JS 代称 JavaScript 引入HTML结构文件有两类方式&#xff1a; 内部脚本 行联式嵌入式外部脚本* 基本语法&#xff1…

npm发布vue3自定义组件库--方法二

npm发布vue3自定义组件库 创建项目 vue create test-ui自定义组件 创建自定义组件&#xff0c;组件名称根据你的需求来&#xff0c;最好一个组件一个文件夹&#xff0c;下图是我的示例。 src/components 组件和你写页面一样&#xff0c;所谓组件就是方便实用&#xff0c;不…

NotePad++ 在行前/行后添加特殊字符内容方法

我们在处理数据时&#xff0c;会遇到需要在每行数据前面、后面、开头、结尾添加各种不一样的字符 如果数据不多&#xff0c;我们可以自己手动的去添加&#xff0c;但如果达到了成百上千行&#xff0c;此时再机械的手动添加是不现实的 这里教给大家如何快速的在数据每行的前后…

华为云云耀云服务器L实例评测|cento7.9在线使用cloudShell下载rpm解压包安装mysql并开启远程访问

文章目录 ⭐前言⭐使用华为cloudShell连接远程服务器&#x1f496; 进入华为云耀服务器控制台&#x1f496; 选择cloudShell ⭐安装mysql压缩包&#x1f496; wget下载&#x1f496; tar解压&#x1f496; 安装步骤&#x1f496; 初始化数据库&#x1f496; 修改密码&#x1f4…

JavaCTF记录

Springmvcdemo 在没有提升权限之前&#xff0c;整个环境只有Cookie是可控的&#xff0c;并且提升权限也是要通过cookie来&#xff0c;先看看它对cookie做了什么&#xff0c;看一下过滤器 public void doFilter(ServletRequest request, ServletResponse response, FilterChai…

Python实现猎人猎物优化算法(HPO)优化随机森林回归模型(RandomForestRegressor算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 猎人猎物优化搜索算法(Hunter–prey optimizer, HPO)是由Naruei& Keynia于2022年提出的一种最新的…

Pyhton压缩JS代码

文章目录 1.安装依赖2.目录结构3.代码4.执行结果 1.安装依赖 pip install jsmin2.目录结构 3.代码 import jsmindef run(src_path, tgt_path):with open(src_path, "r", encodingutf-8) as input_file:with open(tgt_path, "w", encodingutf-8) as outpu…

外贸型CRM软件系统的作用

外贸企业在国际市场上面临着大量的竞争和风险&#xff0c;需要不断创新发展&#xff0c;提高自身的竞争力&#xff0c;但又受制于客户管理、业务效率、数据利用和风险控制等方面的不足。为了解决外贸企业面临的问题和挑战&#xff0c;外贸CRM系统应运而生。那么&#xff0c;什么…

面试(架构,网络)

java八股 treemap和linkdedhashmap区别&#xff0c;实现原理 https://blog.csdn.net/shidebin/article/details/126814905 架构 https://www.cnblogs.com/crazymakercircle/p/17197091.htmlhttps://www.cnblogs.com/crazymakercircle/p/17197091.html 羊了个羊https://www.c…

【LeetCode-简单题】1047. 删除字符串中的所有相邻重复项

文章目录 题目方法一&#xff1a;利用栈做匹配方法二&#xff1a;消消乐 题目 方法一&#xff1a;利用栈做匹配 class Solution {public String removeDuplicates(String s) {Deque<Character> deque new LinkedList<>();StringBuffer str new StringBuffer();fo…

封装七牛云存储工具类

文章目录 封装七牛云存储工具类&#xff08;为啥选择七牛云&#xff1f;当然是因为它能免费使用喽&#xff01;&#xff01;&#xff01;白嫖怪哈哈哈&#xff01;&#xff01;&#xff01;&#xff09;图片存储方案Java SDK操作七牛云封装工具类 封装七牛云存储工具类&#xf…

如何在 Excel 中求平方根

需要在 Excel 中求一个数字的平方根吗&#xff1f;使用几个内置的 Excel 函数和公式可以轻松计算平方根。在本分步指南中&#xff0c;您将学习在 Excel 中计算平方根的 5 种不同方法&#xff0c;包括使用 SQRT 函数、POWER 函数、指数公式、VBA 代码和 Power Query。跟随教程&a…

我学编程全靠B站了,真香-国外篇(第三期)

你好&#xff0c;我是Martin。 今天来点猛料&#xff0c;给大家推荐点我的压箱收藏-国外知名大学的公开课。 我推荐的不多&#xff0c;本着少就是多的原则&#xff0c;只给大家推荐我看过最好的五门视频&#xff0c;主要是来自两所国外高校&#xff1a;MIT美国麻省理工、CMU卡…