Pytorch nn.Linear()的基本用法与原理详解及全连接层简介

news2024/11/25 14:33:21

主要引用参考:
https://blog.csdn.net/zhaohongfei_358/article/details/122797190
https://blog.csdn.net/weixin_43135178/article/details/118735850

nn.Linear的基本定义

nn.Linear定义一个神经网络的线性层,方法签名如下:

torch.nn.Linear(in_features, # 输入的神经元个数
           out_features, # 输出神经元个数
           bias=True # 是否包含偏置
           )

Linear其实就是对输入 X n × i X_{n\times i} Xn×i执行了一个线性变换,即:
Y n × o = X n × i W i × o + b Y_{n\times o}=X_{n\times i}W_{i\times o}+b Yn×o=Xn×iWi×o+b
其中 W W W是模型想要学习的参数, W W W的维度为 W i × o W_{i\times o} Wi×o,b是o维的向量偏置,n为输入向量的行数(例如,你想一次输入10个样本,即batch_size为10,则n=10),i为输入神经元的个数(例如你的样本特征为5,则i=5),o为输出神经元的个数。

示例:

from torch import nn
import torch

model = nn.Linear(2, 1) # 输入特征数为2,输出特征数为1
input = torch.Tensor([1, 2]) # 给一个样本,该样本有2个特征(这两个特征的值分别为1和2)
output = model(input)
output
tensor([-1.4166], grad_fn=<AddBackward0>)

我们的输入为[1,2],输出了[-1.4166]。可以查看模型参数验证一下上述的式子:

# 查看模型参数
for param in model.parameters():
    print(param)
Parameter containing:
tensor([[ 0.1098, -0.5404]], requires_grad=True)
Parameter containing:
tensor([-0.4456], requires_grad=True)

可以看到,模型有3个参数,分别为两个权重和一个偏执。计算可得:
y = [ 1 , 2 ] ∗ [ 0.1098 , − 0.5404 ] T − 0.4456 = − 1.4166 y=[1,2]*[0.1098,-0.5404]^T-0.4456=-1.4166 y=[1,2][0.1098,0.5404]T0.4456=1.4166


实战

假设我们的一次输入三个样本A,B,C(即batch_size为3),每个样本的特征数量为5:

A: [0.1,0.2,0.3,0.3,0.3]
B: [0.4,0.5,0.6,0.6,0.6]
C: [0.7,0.8,0.9,0.9,0.9]

则我们的输入向量 X 3 × 5 X_{3\times 5} X3×5为:

X = torch.Tensor([
    [0.1,0.2,0.3,0.3,0.3],
    [0.4,0.5,0.6,0.6,0.6],
    [0.7,0.8,0.9,0.9,0.9],
])
X
tensor([[0.1000, 0.2000, 0.3000, 0.3000, 0.3000],
        [0.4000, 0.5000, 0.6000, 0.6000, 0.6000],
        [0.7000, 0.8000, 0.9000, 0.9000, 0.9000]])

定义线性层,我们的输入特征为5,所以in_feature=5,我们想让下一层的神经元个数为10,所以out_feature=10,则模型参数为: W 5 × 10 W_{5\times 10} W5×10

model = nn.Linear(in_features=5, out_features=10, bias=True)

经过线性层,其实就是做了一件事,即:
Y 3 × 10 = X 3 × 5 W 5 × 10 + b Y_{3\times 10}=X_{3\times 5}W_{5\times 10}+b Y3×10=X3×5W5×10+b
具体表示为:
[ Y 00 Y 01 ⋯ Y 08 Y 09 Y 10 Y 11 ⋯ Y 18 Y 19 Y 20 Y 21 ⋯ Y 28 Y 29 ] = [ X 00 X 01 X 02 X 03 X 04 X 10 X 11 X 12 X 13 X 14 X 20 X 21 X 22 X 23 X 23 ] [ W 00 W 01 ⋯ W 08 W 09 W 10 W 11 ⋯ W 18 W 19 W 20 W 21 ⋯ W 28 W 29 W 30 W 31 ⋯ W 38 W 39 W 40 W 41 ⋯ W 48 W 49 ] + b \begin{equation} \left[ \begin{array}{ccc} Y_{00} & Y_{01} &\cdots & Y_{08} &Y_{09} \\ Y_{10} & Y_{11} &\cdots & Y_{18} &Y_{19} \\ Y_{20} & Y_{21} &\cdots & Y_{28} &Y_{29} \end{array} \right] =\left[ \begin{array}{ccc} X_{00} & X_{01} &X_{02} & X_{03} &X_{04} \\ X_{10} & X_{11} &X_{12} & X_{13} &X_{14} \\ X_{20} & X_{21} &X_{22} & X_{23} &X_{23} \end{array}\nonumber \right] \left[ \begin{array}{ccc} W_{00} & W_{01} &\cdots & W_{08} &W_{09} \\ W_{10} & W_{11} &\cdots & W_{18} &W_{19} \\ W_{20} & W_{21} &\cdots & W_{28} &W_{29} \\ W_{30} & W_{31} &\cdots & W_{38} &W_{39} \\ W_{40} & W_{41} &\cdots & W_{48} &W_{49} \\ \end{array} \right] +b \end{equation}\nonumber Y00Y10Y20Y01Y11Y21Y08Y18Y28Y09Y19Y29 = X00X10X20X01X11X21X02X12X22X03X13X23X04X14X23 W00W10W20W30W40W01W11W21W31W41W08W18W28W38W48W09W19W29W39W49 +b

个人的理解:比如 X X X第一行和 W W W矩阵的第一列相乘就相当于对样本A做了全局卷积,最后得到了1个特征,因为 W W W有10列,所以最后得到10个特征,也就是把5个特征转变为了10个特征。

其中 X i . X_i. Xi.就表示第i个样本, W . j W_{.j} W.j表示所有输入神经元到第j个输出神经元的权重。
在这里插入图片描述

注意:这里图有点问题,应该是 W 00 , W 01 , W 02 , . . . , W 07 , W 08 , W 09 W_{00}, W_{01}, W_{02}, ..., W_{07}, W_{08},W_{09} W00,W01,W02,...,W07,W08,W09(我没觉得图有问题)

因为有三个样本,所以相当于依次进行了三次 Y 3 × 10 = X 3 × 5 W 5 × 10 + b Y_{3\times 10}=X_{3\times 5}W_{5\times 10}+b Y3×10=X3×5W5×10+b,然后再将三个 Y 1 × 10 Y_{1\times 10} Y1×10叠在一起

经过线性层后,我们最终的到了3×10维的矩阵,即 输入3个样本,每个样本维度为5,输出为3个样本,将每个样本扩展成了10维

model(X).size()
torch.Size([3, 10])

全连接层

概述
全连接层 Fully Connected Layer 一般位于整个卷积神经网络的最后,负责将卷积输出的二维特征图转化成一维的一个向量,由此实现了端到端的学习过程(即:输入一张图像或一段语音,输出一个向量或信息)。全连接层的每一个结点都与上一层的所有结点相连因而称之为全连接层。由于其全相连的特性,一般全连接层的参数也是最多的。

主要作用
全连接层的主要作用就是将前层(卷积、池化等层)计算得到的特征空间映射样本标记空间。简单的说就是将特征表示整合成一个值,其优点在于减少特征位置对于分类结果的影响,提高了整个网络的鲁棒性。
全连接在整个网络卷积神经网络中起到“分类器”的作用,如果说卷积层、池化层和激活函数等操作是将原始数据映射到隐层特征空间的话,全连接层则起到将学到的特征表示映射到样本的标记空间的作用。其实,就是把特征整合到一起,方便交给最后的分类器或者回归。

实际操作
在实际使用中,全连接层可由卷积操作实现:对前层是全连接的全连接层可以转换为卷积核为1*1的卷积;而前层是卷积层的全连接层可以转换为卷积核为前层卷积输出结果的高和宽一样大小的全局卷积。

一个通俗的例子:
以VGG-16为例,对224x224x3的输入,最后一层卷积可得输出为7x7x512,如后层是一层含4096个神经元的FC,则可用卷积核为7x7x512x4096的全局卷积来实现这一全连接运算过程,其中该卷积核参数如下:“filter size = 7, padding = 0, stride = 1, D_in = 512, D_out = 4096”经过此卷积操作后可得输出为1x1x4096。如需再次叠加一个2048的FC,则可设定参数为“filter size = 1, padding = 0, stride = 1, D_in = 4096, D_out = 2048”的卷积层操作(个人理解就是用1×1×4096×2048的卷积核来卷积)。(参考:https://www.zhihu.com/question/41037974/answer/150522307)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1321239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AT32F403如何扩大SRAM

配置方法 使用雅特力的ICP 进行配置(可在官网下载) (1)当连接上芯片后,点击设备操作->选择字节 (2)选择224KB SRAM (3)然后点击应用到设备,(可以点击从设备加载,来看当前的配置) (4)打开keil5魔术棒图标 ,将Target中的IRAM1第二个选项从0x10000改为0x3800。…

虚拟电厂 能源物联新方向

今年有多热&#xff1f;据上海市气象局官微消息&#xff0c;5月29日13时09分&#xff0c;徐家汇站气温达36.1℃&#xff0c;打破了百年来的当地5月份气温*高纪录。不仅如此&#xff0c;北京、四川、江西、湖南、广东、广西等地也频频发布高温预警。 伴随着居民用电急剧攀升&am…

4.1 媒资管理模块 - Nacos与Gateway搭建

文章目录 媒资管理模块 - 媒资项目搭建一、需求分析1.1 介绍1.2 数据模型1.3 分析网关 二、 搭建Nacos2.1 服务发现中心2.2.1 Maven2.2.2 配置Nacos 2.2 配置中心2.2.1 介绍2.2.2 Maven 坐标2.2.3 配置 content-api 工程2.2.4 配置 content-service 工程2.2.5 配置 system-api …

基础算法(5):滑动窗口

1.何为滑动窗口&#xff1f; 滑动窗口其实也是一种算法&#xff0c;主要有两类&#xff1a;一类是固定窗口&#xff0c;一类是可变窗口。固定的窗口只需要一个变量记录&#xff0c;而可变窗口需要两个变量。 2.固定窗口 就像上面这个图一样。两个相邻的长度为4的红色窗口&…

HTML---CSS美化网页元素

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.div 标签&#xff1a; <div>是HTML中的一个常用标签&#xff0c;用于定义HTML文档中的一个区块&#xff08;或一个容器&#xff09;。它可以包含其他HTML元素&#xff0c;如文本、图像…

探秘 AJAX:让网页变得更智能的异步技术(上)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

如何编写好的测试用例?

对于软件测试工程师来说&#xff0c;设计测试用例和提交缺陷报告是最基本的职业技能。是非常重要的部分。一个好的测试用例能够指示测试人员如何对软件进行测试。在这篇文章中&#xff0c;我们将介绍测试用例设计常用的几种方法&#xff0c;以及如何编写高效的测试用例。 一、…

iPhone 17Pro/Max或升级4800万像素长焦镜头,配备自研Wi-Fi 7芯片。

iPhone 16未至&#xff0c;关于iPhone 17系列的相关消息就已经放出&#xff0c;到底是谁走漏了风声。 海通国际证券技术分析师Jeff Pu近日发布报告称&#xff0c;苹果将为2025年推出的iPhone 17ProMax配备4800万像素的长焦镜头。经调查&#xff0c;该分析师认为提升iPhone拍摄方…

如何在华为云上购买ECS及以镜像的方式部署华为云欧拉操作系统 (HCE OS)

写在前面 工作中遇到&#xff0c;简单整理博文内容为 华为云开发者认证 实验笔记https://edu.huaweicloud.com/certificationindex/developer/9bf91efb086a448ab4331a2f53a4d3a1理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在…

Nginx快速入门:Nginx应用场景、安装与部署(一)

1. Nginx简介 Nginx 是一个高性能的 HTTP 和反向代理服务器&#xff0c;也是一个非常流行的开源 Web 服务器软件。它是由俄罗斯程序员 Igor Sysoev 开发的&#xff0c;最初是为了解决在高并发场景下的C10k 问题&#xff08;即一个服务器进程只能处理 10,000 个并发连接&#x…

早期的OCR是怎么识别图片上的文字的?

现在的OCR技术融合了人工智能技术&#xff0c;通过深度学习&#xff0c;无论是识别的准确率还是效果都非常不错&#xff0c;那您知道在早期的OCR是通过什么技术来实现的吗&#xff1f;如果您不知道&#xff0c;那么&#xff0c;就让我来告诉您&#xff1a;它主要是基于字符的几…

DiffUtil + RecyclerView 在 Kotlin中的使用

很惭愧, 做了多年的Android开发还没有使用过DiffUtil这样解放双手的工具。 文章目录 1 DiffUtil 用来解决什么问题?2 DiffUtil 是什么?3 DiffUtil的使用4 参考文章 1 DiffUtil 用来解决什么问题? 先举几个实际开发中的例子帮助我们感受下: 加载内容流时,第一次加载了ABC,…

数据分析思维导图

参考&#xff1a; https://zhuanlan.zhihu.com/p/567761684?utm_id0 1、数据分析步骤地图 2、数据分析基础知识地图 3、数据分析技术知识地图 4、数据分析业务流程 5、数据分析师能力体系 6、数据分析思路体系 7、电商数据分析核心主题 8、数据科学技能书知识地图 9、数据挖掘…

文章解读与仿真程序复现思路——电力系统自动化EI\CSCD\北大核心《基于碳捕集-电转气的矿区综合能源系统协同优化调度》

这个标题涉及到碳捕集、电力转化为气体&#xff08;可能是指电力转化为氢气等&#xff09;、矿区综合能源系统以及协同优化调度等概念。让我们逐步解读&#xff1a; 碳捕集&#xff08;Carbon Capture&#xff09;&#xff1a; 这指的是通过不同技术手段捕获和隔离工业过程中产…

输电线路定位:精确导航,确保电力传输安全

在现代社会中&#xff0c;电力作为生活的基石&#xff0c;其安全稳定运行至关重要。而输电线路作为电力传输的重要通道&#xff0c;其故障定位和修复显得尤为重要。恒峰智慧科技将为您介绍一种采用分布式行波测量技术的输电线路定位方法&#xff0c;以提高故障定位精度&#xf…

新版Android Studio Logcat 筛选日志

下载了新版的Android Studio&#xff0c;android-studio-2022.3.1.21-mac_arm&#xff0c;记录一下新版本AS的logcat过滤日志条件 1. 按照包名过滤 1.1 过滤当前包名的日志 package:mine 1.2 过滤其他包名日志 package:com.example.firstemptyapplication 2. 按照日志等级过滤…

STM32与Freertos入门(三)任务的创建、删除

1、串口配置 首先将串口进行配置&#xff0c;后续经常会应用&#xff0c;具体步骤点击&#xff1a;串口配置。 2、任务 创建一个任务&#xff0c;就是开辟一个空间、每个任务中都会有while&#xff08;1&#xff09;死循环。 2.1相关函数 动态创建&#xff1a;xTaskCreate…

Github 2023-12-19开源项目日报 Top10

根据Github Trendings的统计&#xff0c;今日(2023-12-19统计)共有10个项目上榜。根据开发语言中项目的数量&#xff0c;汇总情况如下&#xff1a; 开发语言项目数量Python项目4Rust项目2非开发语言项目2C#项目1TypeScript项目1 Avalonia: 跨平台UI框架和Avalonia XPF 创建周…

20231218给Firefly的AIO-3399J【RK3399】开发板刷Android12挖掘机方案

20231218给Firefly的AIO-3399J【RK3399】开发板刷Android12挖掘机方案 2023/12/18 21:07 一、整体编译Rockchip的的Android12的挖掘机方案&#xff01; 由于RK3399的Android12系统默认是IND工业方案&#xff0c;需要修改一下【为挖掘机方案】。 Z:\3TB\81rk_android12_220722\…

鸿蒙端H5容器化建设——JSB通信机制建设

1. 背景 2023年鸿蒙开发者大会上&#xff0c;华为宣布为了应对国外技术封锁的潜在风险&#xff0c;2024年的HarmonyOS NEXT版本中将不再兼容Android&#xff0c;并推出鸿蒙系统以及其自研的开发框架&#xff0c;形成开发生态闭环。同时&#xff0c;在更高维度上华为希望将鸿蒙…