【深度学习实验】前馈神经网络(三):自定义两层前馈神经网络(激活函数logistic、线性层算子Linear)

news2025/1/13 10:32:14

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集

 2. 激活函数logistic

3. 线性层算子 Linear

4. 两层的前馈神经网络MLP

5. 模型训练


一、实验介绍

  • 本实验实现了一个简单的两层前馈神经网络
    • 激活函数logistic
    • 线性层算子Linear

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch
from torch import nn

1. 构建数据集

input = torch.ones((1, 10))

         创建了一个输入张量`input`,大小为(1, 10)。

 2. 激活函数logistic

def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))

        logistic函数的特点是将输入值映射到一个介于0和1之间的输出值,可以看作是一种概率估计。当输入值趋近于正无穷大时,输出值趋近于1;当输入值趋近于负无穷大时,输出值趋近于0。因此,logistic函数常用于二分类问题,将输出值解释为概率值,可以用于预测样本属于某一类的概率。在神经网络中,logistic函数的引入可以引入非线性特性,使得网络能够学习更加复杂的模式和表示。

3. 线性层算子 Linear

class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.grads = {}
        self.inputs = None

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(inputs, self.params['W']) + self.params['b']
        return outputs
  • Linear类是一个自定义的线性层,继承自nn.Module
    • 它具有两个参数:input_sizeoutput_size,分别表示输入和输出的大小。
  • 在初始化时,创建了两个参数:Wb,分别代表权重和偏置,都是可训练的张量,并通过nn.Parameter进行封装。
    • paramsgrads是字典类型的属性,用于存储参数和梯度;
    • inputs是一个临时变量,用于存储输入。
  • forward方法实现了前向传播的逻辑,利用输入和参数计算输出。

4. 两层的前馈神经网络MLP

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, output_size)

    def forward(self, x):
        z1 = self.fc1(x)
        a1 = logistic(z1)
        z2 = self.fc2(a1)
        a2 = logistic(z2)
        return a2
  • 初始化时创建了两个线性层Linear对象:fc1fc2
  • forward方法实现了整个神经网络的前向传播过程:
    • 输入x首先经过第一层线性层fc1
    • 然后通过logistic函数进行激活,
    • 再经过第二层线性层fc2
    • 最后再经过一次logistic函数激活,
    • 并返回最终的输出。

5. 模型训练

input_size, hidden_size, output_size = 10, 5, 2
net = MLP(input_size, hidden_size, output_size)
output = net(input)
print(output)
  • 定义了三个变量input_sizehidden_sizeoutput_size,分别表示输入大小、隐藏层大小和输出大小。
  • 创建了一个MLP对象net,并将输入input传入模型进行前向计算,得到输出output。最后将输出打印出来。

6. 代码整合

# 导入必要的工具包
import torch
from torch import nn


# 线性层算子,请一定注意继承自 nn. Module, 这会帮你解决许多细节上的问题
class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.grads = {}
        self.inputs = None

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(inputs, self.params['W']) + self.params['b']
        return outputs


# 实现一个两层的前馈神经网络
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, output_size)

    def forward(self, x):
        z1 = self.fc1(x)
        a1 = logistic(z1)
        z2 = self.fc2(a1)
        a2 = logistic(z2)
        return a2


# Logistic 函数
def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))

input = torch.ones((1, 10))
input_size, hidden_size, output_size = 10, 5, 2
net = MLP(input_size, hidden_size, output_size)
output = net(input)
print(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1036983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

41. Linux系统配置FTP服务器并在QT中使用QFtp实现文件上传

1. 说明 这篇博客主要记录一些在Linux系统中搭建FTP服务器时踩过的一些坑,以及在使用QFtp上传文件时需要注意的问题。 2. FTP环境搭建 在linux系统中,需要安装vsftpd,可以在终端中输入下面的命令进行安装: sudo apt-get install vsftpd使用上述命令安装后,系统中会有一…

day28IO流(字节流字符流)

1. IO概述 1.1 什么是IO 生活中,你肯定经历过这样的场景。当你编辑一个文本文件,忘记了ctrls ,可能文件就白白编辑了。当你电脑上插入一个U盘,可以把一个视频,拷贝到你的电脑硬盘里。那么数据都是在哪些设备上的呢&a…

机试算法学习

又到了一年一度的校招干饭环节,本人不得已以应届生的身份卷入了这场洪流,让我们各自加油吧! 蛇形矩阵 xx机考编程题 题目描述 输入两个整数 n和 m,输出一个 n 行 m 列的矩阵,将数字 1到 nm按照回字蛇形填充至矩阵中…

【前段基础入门之】=>HTML 标签元素

前言: 在前一章节中,我们讲解认识了,HTML 的概念,以及它的标准文档结构,所以本章节就带来 HTML 学习的第二步,学习了解HTML 的排版标签元素。 文章目录 文档排版标签元素语义化标签块级元素与行内元素文本标…

Linux桌面环境中应用程序无法启动图形交互界面

现象: 点击永中office或者金山office快捷图标无法启动对应的程序。 从命令行执行对应的程序则提示 按照提示安装组件 再次执行命令行程序 原因探析: /opt/Yozosoft/Yozo_Office/Yozo_Writer.bin: error while loading shared libraries: libgdk-x11-2.0.…

SQL 如何提取多级分类目录

前言 POI数据处理,原始数据为csv格式,整理入库至PostGreSQL,本例使用PostGreSQL13版本。 一、POI POI(一般作为Point of Interest的缩写,也有Point of Information的说法),通常称作兴趣点&am…

Cloudflare分析第一天:简单的算法反混淆

记录1: Cloudflare 加密方式为动态JS,每次请求JS文件都会变化,笨方式,先复制一份出来分析看! 原JS: window._cf_chl_opt.uaSR true; window._cf_chl_opt.uaO false; function(ia, fy, fz, fA, fB, fC, fM, fV, fW…

Windows 11 家庭中文版添加本地安全策略

一、报错 Windows11中打开本地组策略编辑器(cmd中输入gpedit.msc),报错: 二、解决 1、新建txt文件,文件名任意,将下面的内容复制粘贴进去。2、将文件后缀名由txt改为cmd。3、以管理员身份执行该cmd文件,安装本地安全…

java Spring Boot生成图片二维码

首先 我们要引入依赖 pom.xml中插入 <dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.4.1</version> </dependency> <dependency><groupId>com.google.zxing</grou…

Jetpack Compose 的简单 MVI 框架

Jetpack Compose 的简单 MVI 框架 在 Jetpack Compose 应用程序中管理状态的一种简单方法 选择正确的架构是至关重要的&#xff0c;因为架构变更后期代价高昂。MVP已被MVVM和MVI取代&#xff0c;而MVI更受欢迎。MVI通过强制实施结构化的状态管理方法&#xff0c;只在reducer中…

Linux 快捷键

1、快捷键小操作 1.1、ctrl c 强制停止 Linux某些程序的运行&#xff0c;如果想要强制停止它&#xff0c;可以使用快捷键ctrl c 命令输入错误&#xff0c;也可以通过快捷键ctrl c&#xff0c;退出当前输入&#xff0c;重新输入 1.2、ctrl d 退出或登出 可以通过快捷键&…

使用Mybatis generator自动生成代码,仅限Oracle数据库

一、使用Mybatis generator自动生成代码&#xff0c;仅限Oracle数据库 使用Mybatis generator自动生成代码&#xff0c;仅限Oracle数据库 一、在pom.xml文件中引入所需要的依赖和插件 <dependency><groupId>org.mybatis.generator</groupId><artifactI…

VUE日期只选择日月,表格导入功能,表格下载模版功能

1.日期选择日月&#xff1a;参考https://blog.csdn.net/Oct_Somnus/article/details/129989865?spm1001.2101.3001.6661.1&utm_mediumdistribute.pc_relevant_t0.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1-129989865-blog-116654979.235%5Ev38%5Epc_relevant_sort_b…

轻松使用androidstudio交叉编译libredwg库

对于安卓或嵌入式开发者而言,交叉编译是再熟悉不过的操作了,可是对于一些刚入门或初级开发者经常会遇到这样的问题:如何交叉编译C++库来生成安卓下的so库呢? 最近有一些粉丝找到我求救,那么我最近刚好有空大致研究了下,帮他们成功编译了其中一个libredwg的C++库,这篇文章…

Java 21 / JDK 21 (LTS) GA

Java 21 / JDK 21 已正式 GA&#xff0c;此版本是继 JDK 17 后的长期支持版本 (LTS)&#xff0c;Oracle 将为其提供至少八年的技术支持和更新。 本版本是Java SE平台21版的参考实现&#xff0c;由Java社区流程中的JSR 396指定。 正式稳定功能 JEP 444&#xff1a;虚拟线程JEP…

测试C#图像文本识别模块Tesseract的基本用法

微信公众号“dotNET跨平台”的文章《c#实现图片文体提取》&#xff08;参考文献3&#xff09;介绍了C#图像文本识别模块Tesseract&#xff0c;后者是tesseract-ocr&#xff08;参考文献2&#xff09; 的C#封装版本&#xff0c;目前版本为5.2&#xff0c;关于Tesseract的详细介绍…

windows上搭建llama小型私有模型

导言 llama官网是需要多读读的 openAI的付费&#xff0c;让学习LLM的成本不可控。为了省钱&#xff0c;搭建本地LLAMA模型 我的笔记本是近10年前买的配置一般的windows 目标 本地llm可以运行使用llama-cpp-python调用本地llm使用langchain/openai调用本地llm 需要重点说下&…

ESP8266 WiFi物联网智能插座—上位机和下位机通信协议

目录 1、配置节点协议 2、控制节点继电器开关协议 3、节点周期上报数据协议 4、升级节点协议 5、重启节点 本项目自定义了一套上位机和下位机通信协议&#xff0c;协议并不复杂&#xff0c;包含&#xff1a;配置节点、控制节点继电器开关、节点周期上报数据、升级节点和重启节点…

设计模式篇---桥接模式

文章目录 概念结构实例总结 概念 桥接模式&#xff1a;将抽象部分与它的实现部分解耦&#xff0c;使得两者都能够独立变化。 毛笔和蜡笔都属于画笔&#xff0c;假设需要有大、中、小三种型号的画笔&#xff0c;绘画出12种颜色&#xff0c;蜡笔需要3*1236支&#xff0c;毛笔需要…

大数据之Flume

Flume概述 一个高可用&#xff08;稳定&#xff09;&#xff0c;高可靠&#xff08;稳定&#xff09;&#xff0c;分布式的海量日志采集&#xff0c;聚合和传输的系统。Flume基于流式架构&#xff0c;灵活简单。日志文件即txt文件&#xff0c;不能传输音频&#xff0c;视频&am…