Pytorch深度学习-----神经网络之线性层用法

news2024/11/24 15:02:09

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)


文章目录

  • 系列文章目录
  • 一、线性层是什么?
    • 1.官网解释
    • 2.nn.Linear函数参数介绍
  • 二、实战演示
    • 1.将CIFAR10图片数据集进行线性变换


一、线性层是什么?

线性层是深度学习中常用的一种基本层类型。它也被称为全连接层或仿射层。线性层的作用是将输入数据与权重矩阵相乘,然后加上偏置向量,最后输出一个新的特征表示。

具体来说,线性层可以表示为 Y = XW + b,其中 X 是输入数据W 是权重矩阵b 是偏置向量Y 是输出结果。这个过程可以看作是对输入数据进行线性变换的操作。

1.官网解释

官网访问:LINEAR
如下图所示
在这里插入图片描述
在这里插入图片描述
由此可见,每一层的某个神经元的值都为前一层所有神经元的值的总和。

2.nn.Linear函数参数介绍

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

其中最重要的三个参数为in_features, out_features, bias

in_features, 表示输入的特征值大小,即输入的神经元个数
out_features,表示输出的特征值大小,即经过线性变换后输出的神经元个数
bias,表示是否添加偏置

二、实战演示

在这里插入图片描述
预定要的in_features为1,1,x形式
out_features为1,1,y的形式

1.将CIFAR10图片数据集进行线性变换

代码如下:

import torch
import torchvision
from torch.utils.data import DataLoader

# 准备数据
test_set = torchvision.datasets.CIFAR10("dataset",
                                        train=False,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 加载数据集
dataloader = DataLoader(test_set,batch_size=64)

# 查看输入的通道数
# for data in dataloader:
#     imgs, target = data
#     print(imgs.shape)  # torch.Size([64, 3, 32, 32])
#     # 将img进行reshape成1,1,x的形式
#     input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算x
#     print(input.shape) # torch.Size([1, 1, 1, 196608])

# 搭建神经网络,设置预定的输出特征值为10
class Lgl(torch.nn.Module):
    def __init__(self):
        super(Lgl, self).__init__()
        self.linear1 = torch.nn.Linear(196608,10)  # 输入数据的特征值196608,输出特征值10
    def forward(self, input):
        output = self.linear1(input)
        return output
# 实例化
l = Lgl()
# 进行线性操作

for data in dataloader:
    imgs, target = data
    print(imgs.shape)  # torch.Size([64, 3, 32, 32])
    # 将img进行reshape成1,1,x的形式
    input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算x
    output = l(input)
    print(output.shape) # torch.Size([1, 1, 1, 10])


原先的图片shape:torch.Size([64, 3, 32, 32])
reshape后的图片shape:torch.Size([1, 1, 1, 196608])
经过线性后的图片shape:torch.Size([1, 1, 1, 10])
原先的图片shape:torch.Size([64, 3, 32, 32])
reshape后的图片shape:torch.Size([1, 1, 1, 196608])
经过线性后的图片shape:torch.Size([1, 1, 1, 10])
……

除了使用reshape后,还可以使用torch.flatten()进行修改尺寸,将其自动修改为一维。
torch.flatten(input, start_dim=0, end_dim=- 1)
将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor

修改成flatten后代码如下

import torch
import torchvision
from torch.utils.data import DataLoader

# 准备数据
test_set = torchvision.datasets.CIFAR10("dataset",
                                        train=False,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 加载数据集
dataloader = DataLoader(test_set,batch_size=64)

# 查看输入的通道数
# for data in dataloader:
#     imgs, target = data
#     print(imgs.shape)  # torch.Size([64, 3, 32, 32])
#     # 将img进行reshape成1,1,x的形式
#     input = torch.reshape(imgs,(1,1,1,-1)) # 每次一张图,1通道,1*自动计算x
#     print(input.shape) # torch.Size([1, 1, 1, 196608])

# 搭建神经网络,设置预定的输出特征值为10
class Lgl(torch.nn.Module):
    def __init__(self):
        super(Lgl, self).__init__()
        self.linear1 = torch.nn.Linear(196608,10)  # 输入数据的特征值196608,输出特征值10
    def forward(self, input):
        output = self.linear1(input)
        return output
# 实例化
l = Lgl()
# 进行线性操作

for data in dataloader:
    imgs, target = data
    print(f"原先的图片shape:{imgs.shape}")  # torch.Size([64, 3, 32, 32])
    # 将img进行reshape成1,1,x的形式
    input = torch.flatten(imgs) # 每次一张图,1通道,1*自动计算x
    print(f"flatten后的图片shape:{input.shape}")
    output = l(input)
    print(f"经过线性后的图片shape:{output.shape}") # torch.Size([1, 1, 1, 10])


原先的图片shape:torch.Size([64, 3, 32, 32])
flatten后的图片shape:torch.Size([196608])
经过线性后的图片shape:torch.Size([10])
原先的图片shape:torch.Size([64, 3, 32, 32])
flatten后的图片shape:torch.Size([196608])
经过线性后的图片shape:torch.Size([10])
……

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/825637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[JavaScript游戏开发] 绘制Q版地图、键盘上下左右地图场景切换

系列文章目录 第一章 2D二维地图绘制、人物移动、障碍检测 第二章 跟随人物二维动态地图绘制、自动寻径、小地图显示(人物红点显示) 第三章 绘制冰宫宝藏地图、人物鼠标点击移动、障碍检测 第四章 绘制Q版地图、键盘上下左右地图场景切换 文章目录 系列文章目录前言一、本章节…

Spring?Boot项目如何优雅实现Excel导入与导出功能

目录 背景EasyExcel 问题分析与解决Spring Boot Excel 导入与导出 依赖引入Excel 导入 基本导入功能进阶导入功能Excel 导出 Excel 导入参数校验 开启校验 校验规则定义 Bean Validation 定义校验规则ExcelValidator 接口定义校验规则校验结果接收 异常捕获接收校验结果contro…

无涯教程-Lua - 垃圾回收

Lua使用自动内存管理,该管理使用基于Lua内置的某些算法的垃圾回收。 垃圾收集器暂停 垃圾收集器暂停用于控制垃圾收集器之前需要等待多长时间; Lua的自动内存管理再次调用它。值小于100意味着Lua将不等待下一个周期。同样,此值的较高值将导…

《吐血整理》高级系列教程-吃透Fiddler抓包教程(21)-如何使用Fiddler生成Jmeter脚本-上篇

1.简介 我们知道Jmeter本身可以录制脚本,也可以通过BadBoy,BlazeMeter等工具进行录制,其实Fiddler也可以录制Jmter脚本(而且有些页面,由于安全设置等原因,使用Jmeter直接无法打开录制时,这时就…

PyQT模块、类、控件介绍

最近在搞一些基于PyQT的开发,开发过程中一直对PyQT相关模块、类、控件比较模糊,于是花了一些力气,去收集和整理了一下PyQT的一些基础,希望对大家有帮助! PyQT模块 QtCore模块涵盖了包的核心的非GUI功能,此模…

windows 安装 Mysql 5.7 和8.0

下载链接 https://dev.mysql.com/downloads/mysql/

【echarts饼图】legend显示data中的name和value

效果图: legend自定义显示格式: legend: {formatter: function (name) {let v;optionCheck.series.data.forEach((item) > {if (item.name name) {v item.value;}});return name v;},},全部配置项: const optionCheck reactive(…

Shell脚本学习-MySQL单实例和多实例启动脚本

已知MySQL多实例启动命令为: mysqld_safe --defaults-file/data/3306/my.cnf & 停止命令为: mysqladmin -uroot -pchang123 -S /data/3306/mysql.sock shutdown 请完成mysql多实例的启动脚本的编写: 问题分析: 要想写出脚…

BES 平台 SDK之提示音的添加

本文章是基于BES2700 芯片,其他BESxxx 芯片可做参考,如有不当之处,欢迎评论区留言指出。仅供参考学习用! BES 平台 SDK之按键的配置_谢文浩的博客-CSDN博客 关于系统按键简介可参考上一篇文章。链接如上所示! 一&am…

工作流管理软件的好处:提升效率、优化流程的利器

一旦您投资了工作流管理系统,就没有回头路了。这是让您的员工满意的完美公式,同时确保所有流程以最高效率和及时完成。事实证明,这些实用的工作流程管理工具对为其客户提供基于知识的专业服务的组织有益(在我们的专业服务指南中了…

【3维视觉】3D空间常用算法(点到直线距离、面法线、二面角)

3D空间点到直线的距离 3D空间点到直线的距离 3D空间的曲率 三维空间有三个基本元素,点,线,面。那么曲率是如何定义的呢? 点的曲率? 线的曲率? 面的曲率? 法曲率 设曲面上的曲线在某一点处的切…

【Spring Boot】请求参数传json数组,后端采用(pojo)新增案例(103)

请求参数传json数组,后端采用(pojo)接收的前提条件: 1.pom.xml文件加入坐标依赖:jackson-databind 2.Spring Boot 的启动类加注解:EnableWebMvc 3.Spring Boot 的Controller接受参数采用:Reque…

第一个 vue-cli 项目

一、什么是 vue-cli vue-cli 官方提供的一个脚手架,用于快速生成一个 vue 的项目模板;预先定义好的目录结构及基础代码,就好比咱们在创建 Maven 项目时可以选择创建一个骨架项目,这个骨架项目就是脚手架,我们的开发更加…

为什么数字孪生和GIS高度互补?它们是如何实现互补的?

在数字化时代,数字孪生和GIS作为两项重要技术,它们的融合正日益受到人们的关注和认可。数字孪生是将实体世界与数字世界紧密结合的技术,可以创建实时的虚拟副本,对物理系统进行模拟、优化和预测。而GIS则是用于收集、管理、分析和…

js将当前时间转换成标准的年月日

直接上代码了: /*** * param e 转换成标准的年月日进行拆分* returns */changeCreationtime(e:any) {let year e.getFullYear(),month (e.getMonth() 1) > 9 ? (e.getMonth() 1) : 0 (e.getMonth() 1),day e.getDate() > 9 ? e.getDate() : 0 e.get…

__block的深入研究

__block可以用于解决block内部无法修改auto变量值的问题 __block不能修饰全局变量、静态变量(static) 编译器会将__block变量包装成一个对象 调用的是,从__Block_byref_a_0的指针找到 a所在的内存,然后修改值 第一层拷贝&…

VLAN介绍

目录 VLAN的特点: VLAN的好处: VLAN的实现原理 VLAN标签 VLAN的划分方式 接口划分VLAN--接口类型 Access接口 Trunk接口 Hybrid接口 实现VLAN之间通信 使用路由器物理接口 使用子接口 使用三层交换机的VLANIF接口 配置 VLANIF的转发流程 三层交换机参与下的三层…

【图解】Mask R-CNN 架构

Mask R-CNN 是一种自顶向下(top-down)的姿态估计模型,它是在 Faster R-CNN [44] 这个目标检测框架的基础上扩展而来的。目标检测是指从图像中检测出不同类别的物体,并且输出它们的边界框(bounding box)。 …

exp/imp选项说明

1、exp选项 2、imp选项 3、举例 (1)、imp system/manager filetank logtank fromuser(seapark,amy) touser(seapark1, amy1);(2)、imp system/manager file(paycheck_1,paycheck_2,paycheck_3,paycheck_4) logpaycheck.log filesize1G fully;(3)、imp system/manager fileseap…

【css】解决元素浮动溢出问题

如果一个元素比包含它的元素高&#xff0c;并且它是浮动的&#xff0c;它将“溢出”到其容器之外&#xff1a;然后可以向包含元素添加 overflow: auto;&#xff0c;来解决此问题&#xff1a; 代码&#xff1a; <!DOCTYPE html> <html> <head> <style>…