【Pytorch】进阶学习:基于矩阵乘法torch.matmul()实现全连接层

news2024/11/17 7:56:13

【Pytorch】进阶学习:基于矩阵乘法torch.matmul()实现全连接层

在这里插入图片描述

🌈 个人主页:高斯小哥
🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅和支持~
💡 创作高质量博文(平均质量分92+),分享更多关于深度学习、PyTorch、Python领域的优质内容!(希望得到您的关注~)


🌵文章目录🌵

  • 🚀一、引言
  • 🔍二、全连接层的基本原理
  • 🔩三、使用torch.matmul()实现全连接层
  • 🎛️四、使用PyTorch的nn.Linear模块实现全连接层
  • 🔎五、小结与注意事项
  • 🤝六、实战演练:构建简单的神经网络
  • 📚七、进阶学习:深度神经网络与全连接层
  • 🤝八、期待与你共同进步

🚀一、引言

  在深度学习的世界里,全连接层(Fully Connected Layer)是构建神经网络的基础组件之一。它实际上执行的就是矩阵乘法操作,将输入数据映射到输出空间。在PyTorch中,我们可以使用torch.matmul()函数来实现这一操作。本文将详细解释如何使用torch.matmul()实现全连接层,并通过实例展示其应用。

🔍二、全连接层的基本原理

  全连接层,也称为密集连接层或仿射层,其核心操作就是矩阵乘法。假设输入数据的形状为(batch_size, input_features),全连接层的权重矩阵形状为(output_features, input_features),偏置项的形状为(output_features,)。全连接层的输出可以通过以下公式计算得到:

output = input @ weight.t() + bias

这里,@ 表示矩阵乘法,.t() 表示转置操作。注意,权重矩阵的列数必须与输入数据的特征数相匹配,以便进行矩阵乘法。偏置项则是一个可选的加法操作,用于增加模型的灵活性。

🔩三、使用torch.matmul()实现全连接层

在PyTorch中,我们可以使用torch.matmul()函数来执行矩阵乘法操作,从而实现全连接层。下面是一个简单的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 5

# 创建一个随机的输入张量,形状为(batch_size, input_features)
batch_size = 32
input_tensor = torch.randn(batch_size, input_features)

# 初始化全连接层的权重和偏置项
weight = torch.randn(output_features, input_features)
bias = torch.randn(output_features)

# 使用torch.matmul()实现全连接层的计算
output_tensor = torch.matmul(input_tensor, weight.t()) + bias

# 查看输出张量的形状,应为(batch_size, output_features)
print(output_tensor.shape)  # 输出应为torch.Size([32, 5])

  在上面的代码中,我们首先定义了全连接层的输入和输出特征数。然后,我们创建了一个随机的输入张量input_tensor,其形状为(batch_size, input_features)。接下来,我们初始化了全连接层的权重weight和偏置项bias。最后,我们使用torch.matmul()函数执行矩阵乘法操作,并将结果加上偏置项,得到输出张量output_tensor。通过打印输出张量的形状,我们可以验证其是否符合预期。

🎛️四、使用PyTorch的nn.Linear模块实现全连接层

  虽然我们可以使用torch.matmul()手动实现全连接层,但在实际开发中,更常见的是使用PyTorch提供的nn.Linear模块来创建全连接层。这个模块封装了权重和偏置项的初始化、矩阵乘法以及偏置项的加法操作,使得全连接层的实现更加简洁和方便。

下面是一个使用nn.Linear模块实现全连接层的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 5

# 创建一个随机的输入张量,形状为(batch_size, input_features)
batch_size = 32
input_tensor = torch.randn(batch_size, input_features)

# 使用nn.Linear模块创建全连接层
linear_layer = nn.Linear(input_features, output_features)

# 将输入张量传递给全连接层进行计算
output_tensor = linear_layer(input_tensor)

# 查看输出张量的形状
print(output_tensor.shape)  # 输出应为torch.Size([32, 5])

  在上面的代码中,我们直接使用nn.Linear(input_features, output_features)创建了一个全连接层对象linear_layer。然后,我们将输入张量input_tensor传递给这个全连接层对象,即可得到输出张量output_tensor。这种方式比手动使用torch.matmul()更加简洁,同时也提供了更多的功能和灵活性,例如权重和偏置项的初始化方法、是否包含偏置项等。

🔎五、小结与注意事项

  通过本文的介绍,我们了解了全连接层的基本原理,并学习了如何使用torch.matmul()函数以及nn.Linear模块来实现全连接层。在实际应用中,我们可以根据具体需求选择合适的方式来实现全连接层。需要注意的是,在使用torch.matmul()时,要确保输入张量和权重矩阵的形状匹配,以避免出错。

🤝六、实战演练:构建简单的神经网络

  理解了全连接层的工作原理和如何使用torch.matmul()后,我们可以进一步构建一个简单的神经网络来加深理解。以下是一个使用PyTorch构建和训练简单神经网络的示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 1

batch_size = 32

# 假设的输入和输出数据
X_train = torch.randn(100, input_features)
y_train = torch.randint(0, 2, (100,))  # 假设是二分类问题

# 将数据包装成TensorDataset和DataLoader
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# 定义简单的神经网络模型
class SimpleNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        x = self.sigmoid(x)
        return x


# 初始化模型、损失函数和优化器
model = SimpleNN(input_features, output_features)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        # 前向传播
        outputs = model(inputs)

        # 计算损失
        loss = criterion(outputs.squeeze(), targets.float())

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 测试模型
with torch.no_grad():
    test_data = torch.randn(5, input_features)
    predictions = model(test_data)
    print(predictions)

  在上面的代码中,我们首先定义了一个简单的神经网络模型SimpleNN,它只包含一个全连接层和一个Sigmoid激活函数。然后,我们初始化了模型、损失函数(二分类交叉熵损失)和优化器(随机梯度下降)。接着,我们进行了模型的训练过程,包括前向传播、损失计算、反向传播和参数更新。最后,我们对模型进行了测试,输入了一些随机生成的数据并得到了预测结果。

📚七、进阶学习:深度神经网络与全连接层

  全连接层在深度神经网络中扮演着重要的角色。随着网络深度的增加,全连接层可以帮助模型捕获更复杂的特征和模式。然而,在实际应用中,我们还需要注意一些问题,如过拟合、计算效率等。为了解决这些问题,我们可以采用一些技巧和方法,如添加正则化项、使用Dropout层、优化网络结构等。

  此外,随着深度学习技术的不断发展,越来越多的新型网络结构被提出,如卷积神经网络(CNN)、循环神经网络(RNN)等。这些网络结构在处理图像、语音、文本等不同类型的数据时具有独特的优势。因此,我们可以进一步学习这些网络结构,并结合全连接层来构建更强大的深度学习模型。

🤝八、期待与你共同进步

  🌱 亲爱的读者,非常感谢你每一次的停留和阅读!你的支持是我们前行的最大动力!🙏

  🌐 在这茫茫网海中,有你的关注,我们深感荣幸。你的每一次点赞👍、收藏🌟、评论💬和关注💖,都像是明灯一样照亮我们前行的道路,给予我们无比的鼓舞和力量。🌟

  📚 我们会继续努力,为你呈现更多精彩和有深度的内容。同时,我们非常欢迎你在评论区留下你的宝贵意见和建议,让我们共同进步,共同成长!💬

  💪 无论你在编程的道路上遇到什么困难,都希望你能坚持下去,因为每一次的挫折都是通往成功的必经之路。我们期待与你一起书写编程的精彩篇章! 🎉

  🌈 最后,再次感谢你的厚爱与支持!愿你在编程的道路上越走越远,收获满满的成就和喜悦!祝你编程愉快!🎉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1499511.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在Jetson Xavier NX 开发板上使用VScode执行ROS程序详细过程

1.创建 ROS 工作空间ws 在home下打开终端输入下面指令 mkdir -p xxx_ws/src(必须得有 src) cd 自己命名_ws catkin_make2.启动 vscode cd 自己命名_ws code .3.vscode 中编译 ros 快捷键 ctrl shift B 调用编译,在上方弹窗位置选择:catkin_make:build 可以点击…

Find My产品越来越得到市场认可,伦茨科技ST17H6x芯片赋能厂家

苹果发布AirTag发布以来,大家都更加注重物品的防丢,苹果的 Find My 就可以查找 iPhone、Mac、AirPods、Apple Watch,如今的Find My已经不单单可以查找苹果的设备,随着第三方设备的加入,将丰富Find My Network的版图。产…

数据可视化原理-腾讯-分类散点图

在做数据分析类的产品功能设计时,经常用到可视化方式,挖掘数据价值,表达数据的内在规律与特征展示给客户。 可是作为一个产品经理,(1)如果不能够掌握各类可视化图形的含义,就不知道哪类数据该用…

阿珊详解Vue路由的两种模式:hash模式与history模式

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

PostgreSQL 流复制

文章目录 1.流复制介绍2.异步流复制2.1.主库部署2.2.备库部署2.3.测试 3.同步复制3.1.主库部署3.2.备库部署3.3.测试 4.主备切换 开源中间件 # PostgreSQLhttps://iothub.org.cn/docs/middleware/ https://iothub.org.cn/docs/middleware/postgresql/postgres-stream/1.流复制…

R语言基础的代码语法解译笔记

1、双冒号,即:“::” 要使用某个包里的函数,通常做法是先加载(library)包,再调用函数。最新加载的包的namespace会成为最新的enviroment,某些情况下可能影响函数的结果。而package name::funct…

怎么把amv格式转换成MP4? 几个步骤轻松搞定~

AMV文件格式的起源可以追溯到中国公司Actions Semiconductor,最初作为其MP4播放器的专有视频格式。在数码媒体领域蓬勃发展的时期,AMV格式因其小巧、高度压缩的特性而备受青睐,为便携设备提供了一种有效的视频存储解决方案。 MP4文件格式的特…

游泳池泵/过滤器/氯化器/潜水泵上架美国亚马逊UL1081测试

Ul是美国保险商试验所(UnderwritersLaboratoriesInc.)的简写。UL安全试验所是美国最有权威的,也是世界上从事安全试验和鉴定的较大的民间机构。它是一个独立的、营利的、为公共安全做试验的专业机构。它采用科学的测试方法来研究确定各种材料…

泛微OA服务器获取 token

泛微OA服务器获取 token 文章目录 泛微OA服务器获取 token一、泛微官方方法1 ecology 系统配置2 发放/生成许可证(appid)3 限制许可证使用ip地址(该步骤也可以跳过)4 使用 postman 注册5 获取 token6 访问业务系统接口 二、java 代码获取 token三、封装到…

【随笔记】小程序轮播图,一屏显示三个swiper-item

常见的轮播是一屏显示一个swiper-item,有的时候需要一屏显示三个swiper-item,左右两边都显示出一点 【目前小程序基础库2.12.3 效果正常,3.几的效果会有点不正常】 效果图 wxml <!-- 轮播begin --> <swiper wx:if="{{up_down}}" class="card-swipe…

GO: 快速升级Go版本

由于底层依赖升级了&#xff0c;那我们也要跟着升&#xff0c;go老版本已经不足满足需求了&#xff0c;必须要将版本升级到1.22.0以上 查看当前Go版本 命令查看go版本 go version[rootlocalhost local]# go version go version go1.21.4 linux/amd64 [rootlocalhost local]# …

3/7作业

信号同步 #include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <pthread.h> #include <semaphore.h> sem_t…

Windows系统安装Jupyter Notebook并实现公网访问内网笔记服务

文章目录 1.前言2.Jupyter Notebook的安装2.1 Jupyter Notebook下载安装2.2 Jupyter Notebook的配置2.3 Cpolar下载安装 3.Cpolar端口设置3.1 Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 在数据分析工作中&#xff0c;使用最多的无疑就是各种函数、图表、…

java操作内存,简单讲解varhandle的使用

概述&#xff1a;按理说jdk8的unsafe类就够用了&#xff0c;估计是因为不安全的原因&#xff0c;到jdk9出了个varhandle类&#xff0c;到jdk21的时候出了Arena和MemorySegment,基本就可以取代unsafe类的使用了。这里我主要讲varhandle类&#xff0c;因为大部分人升级jdk顶多升到…

Excel小技巧-筛选带删除线的数据并删除

Excel小技巧-筛选带删除线的数据并删除 1、替换删除线2、筛选空行并删除 今天同事使用 Excel 的时候遇到一个需求&#xff0c;有些内容不在需要时会被标记删除线&#xff0c;后面再删除&#xff0c;但是由于数据比较多&#xff0c;不方便一个个删除&#xff0c;有没有什么办法能…

STM32 通过Modbus协议更改内部Flash(模仿EEPROM)的运行参数

main.c测试 uint8_t uart1RxBuf[64]{0};uint8_t Adc1ConvEnd0; uint8_t Adc2ConvEnd0;int main(void) {/* USER CODE BEGIN 1 *//* USER CODE END 1 *//* MCU Configuration--------------------------------------------------------*//* Reset of all peripherals, Initial…

使用腾讯云快速搭建WordPress网站流程详解

专栏系列文章&#xff1a; WordPress建站主题美化系列教程https://blog.csdn.net/seeker1994/category_12184577.html 一文搞懂WordPress是什么&#xff1f;为什么用它建站&#xff1f;怎么安装与部署&#xff1f; 初次安装WordPress后如何进行网站设置&#xff08;主题安装、…

阿里云k8s环境下,因slb限额导致的发布事故

一、背景 阿里云k8s容器&#xff0c;在发布java应用程序的时候&#xff0c;客户端访问出现500错误。 后端服务是健康且可用的&#xff0c;网关层大量500错误请求&#xff0c;slb没有流入和流出流量。 经过回滚&#xff0c;仍未能解决错误。可谓是一次血的教训&#xff0c;特…

数组与指针之二——二级指针之一

定义是这样&#xff1a; 多级指针&#xff08;二级指针&#xff09;&#xff0c;C语言多级指针的用法详解 (biancheng.net) 这是针对变量&#xff0c;且是一级一级的取的。但是我们经常要面对数组&#xff0c;用到二级指针。如前面第一篇所述&#xff0c;对一维数组名取地址&…