【深度学习实验】前馈神经网络(三):自定义多层感知机(激活函数logistic、线性层算Linear)

news2024/9/28 3:25:05

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集

 2. 激活函数logistic

3. 线性层算子 Linear

4. 两层的前馈神经网络MLP

5. 模型训练


一、实验介绍

  • 本实验实现了一个简单的两层前馈神经网络
    • 激活函数logistic
    • 线性层算子Linear

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch
from torch import nn

1. 构建数据集

input = torch.ones((1, 10))

         创建了一个输入张量`input`,大小为(1, 10)。

 2. 激活函数logistic

def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))

        logistic函数的特点是将输入值映射到一个介于0和1之间的输出值,可以看作是一种概率估计。当输入值趋近于正无穷大时,输出值趋近于1;当输入值趋近于负无穷大时,输出值趋近于0。因此,logistic函数常用于二分类问题,将输出值解释为概率值,可以用于预测样本属于某一类的概率。在神经网络中,logistic函数的引入可以引入非线性特性,使得网络能够学习更加复杂的模式和表示。

3. 线性层算子 Linear

class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.grads = {}
        self.inputs = None

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(inputs, self.params['W']) + self.params['b']
        return outputs
  • Linear类是一个自定义的线性层,继承自nn.Module
    • 它具有两个参数:input_sizeoutput_size,分别表示输入和输出的大小。
  • 在初始化时,创建了两个参数:Wb,分别代表权重和偏置,都是可训练的张量,并通过nn.Parameter进行封装。
    • paramsgrads是字典类型的属性,用于存储参数和梯度;
    • inputs是一个临时变量,用于存储输入。
  • forward方法实现了前向传播的逻辑,利用输入和参数计算输出。

4. 两层的前馈神经网络MLP

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, output_size)

    def forward(self, x):
        z1 = self.fc1(x)
        a1 = logistic(z1)
        z2 = self.fc2(a1)
        a2 = logistic(z2)
        return a2
  • 初始化时创建了两个线性层Linear对象:fc1fc2
  • forward方法实现了整个神经网络的前向传播过程:
    • 输入x首先经过第一层线性层fc1
    • 然后通过logistic函数进行激活,
    • 再经过第二层线性层fc2
    • 最后再经过一次logistic函数激活,
    • 并返回最终的输出。

5. 模型训练

input_size, hidden_size, output_size = 10, 5, 2
net = MLP(input_size, hidden_size, output_size)
output = net(input)
print(output)
  • 定义了三个变量input_sizehidden_sizeoutput_size,分别表示输入大小、隐藏层大小和输出大小。
  • 创建了一个MLP对象net,并将输入input传入模型进行前向计算,得到输出output。最后将输出打印出来。

6. 代码整合

# 导入必要的工具包
import torch
from torch import nn


# 线性层算子,请一定注意继承自 nn. Module, 这会帮你解决许多细节上的问题
class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.grads = {}
        self.inputs = None

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(inputs, self.params['W']) + self.params['b']
        return outputs


# 实现一个两层的前馈神经网络
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, output_size)

    def forward(self, x):
        z1 = self.fc1(x)
        a1 = logistic(z1)
        z2 = self.fc2(a1)
        a2 = logistic(z2)
        return a2


# Logistic 函数
def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))

input = torch.ones((1, 10))
input_size, hidden_size, output_size = 10, 5, 2
net = MLP(input_size, hidden_size, output_size)
output = net(input)
print(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1026147.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一、【漏洞复现系列】Tomcat文件上传 (CVE-2017-12615)

1.1、漏洞原理 描述: Tomcat 是一个小型的轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP 程序的首选。 攻击者将有可能可通过精心构造的攻击请求数据包向服务器上传包含任意代码的 JSP 的webshell文件&#x…

100G QSFP28 100km光模块最新解决方案

随着信息时代的到来,数据传输的速度和距离要求越来越高。目前,易天光通信发布了具有超低成本、可实现100G超长距离传输新方案——100G QSFP28 100km光模块,该方案是在100G ZR4 80km光模块上的全面升级。 一、产品概述 100G ZR4 100km是专为…

requests模块高级用法练习

文章目录 模拟浏览器指纹发送get请求发送post请求文件上传服务器超时 模拟浏览器指纹 打开http://10.9.75.164/php/functions/setcookie.php网页,找到请求头的UA字段,这段信息是浏览器的指纹(包括当前系统、浏览器名称和版本)&am…

【再识C进阶3(上)】详细地认识字符串函数、进行模拟字符串函数以及拓展内容

小编在写这篇博客时,经过了九一八,回想起了祖国曾经的伤疤,勿忘国耻,振兴中华!加油,逐梦少年! 前言 💓作者简介: 加油,旭杏,目前大二,…

【短文】sambe添加用户时报错Failed to add entry for user

2023年9月20日,周三晚上 Samba fails to add a user entry, how do I fix this? - Ask Ubuntu 也就是说,添加的sambe用户必须是Linux操作系统的用户

2023/09/20 day4 qt

做一个动态指针钟表 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QPainter> //绘制事件类 #include <QPaintEvent> //画家类 #include <QTime> #include <QTimer> #include <QTimerEvent> QT_BEGIN…

k8s使用时无法ping通服务器From IP地址 icmp_seq=1 Destination Host Unreachable

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

canvas-绘图库fabric.js简介

一般情况下简单的绘制&#xff0c;其实canvas原生方法也可以满足&#xff0c;比如画个线&#xff0c;绘制个圆形、正方形、加个文案。 let canvas document.getElementById(canvas);canvas.width 1200;canvas.height 600;canvas.style.width 1200px;canvas.style.height 6…

Canal实现Mysql数据同步至Redis、Elasticsearch

文章目录 1.Canal简介1.1 MySQL主备复制原理1.2 canal工作原理 2.开启MySQL Binlog3.安装Canal3.1 下载Canal3.2 修改配置文件3.3 启动和关闭 4.SpringCloud集成Canal4.1 Canal数据结构![在这里插入图片描述](https://img-blog.csdnimg.cn/c64b40c2231a4ea39a95aac81d771bd1.pn…

kafka消费者多线程开发

目录 前言 kafka consumer 设计原理 多线程的方案 参考资料 前言 目前&#xff0c;计算机的硬件条件已经大大改善&#xff0c;即使是在普通的笔记本电脑上&#xff0c;多核都已经是标配了&#xff0c;更不用说专业的服务器了。如果跑在强劲服务器机器上的应用程序依然是单…

java框架-Spring-容器创建过程

java框架-Spring-容器创建源码

pip pip3安装库时都指向python2的库

当在python3的环境下使用pip3安装库时&#xff0c;发现居然都指向了python2的库 pip -V pip3 -V安装命令更改为&#xff1a; python3 -m pip install <package>

CCC数字钥匙设计【BLE】--URSK管理

1、URSK创建流程 URSK的英文全称为&#xff1a;UWB Ranging Secret Key&#xff0c;即UWB安全测距密钥。 在车主配对时会生成URSK&#xff0c;且在车主配对期间&#xff0c;车辆不得尝试生成第二个URSK。 URSK示例: ed07a80d2beb00f785af2627c96ae7c118504243cb2c3226b3679da…

面向面试知识--MySQL数据库与索引

面向面试知识–MySQL数据库与索引 优化难点与面试点 什么是MySQL索引&#xff1f; 索引的MySQL官方定义&#xff1a;索引是帮助MySQL快速获取数据的数据结构。 动力节点原文&#xff1a; MysQL官方对于索引的定义:索引是帮助MySQL高效获取数据的数据结构。 MysQL在存储数据之…

问题usr/bin/env: “python‘: Too many levels of symbolic links太多层链接的bug pycharm

问题描述 解决&#xff1a;建议不要用过去的conda环境了&#xff0c;直接新建一个环境&#xff0c;然后在图片这个步骤的时候务必选择现有的解释器 。&#xff08;产生问题的原因可能就是新建的解释器太多了&#xff09;

Mermaid画流程图可以实现从一条线中间引出另外一条线吗

这张图中开始和操作1之间引出的一条线要怎么表示啊&#xff01;&#xff01;&#xff01; Mermaid是不能实现这样的画法的吗&#xff1f;可是为什么老师就可以画出来&#xff1f;&#xff1f;&#xff1f; 求大佬指教&#xff01;&#xff01;&#xff01;&#xff01;

现场总线学习

文章目录 1.现场总线现状2.数据编码2.1 数字数据的数字编码2.2 数字数据的模拟编码 3.通信方式&#xff01;&#xff01;&#xff01;4.局域网及其拓扑结构5.工业总线协议6.为什么要在can协议的控制器和bus总线之间&#xff0c;连接一个can收发器&#xff1f;7.那其他协议也需要…

vue修改node_modules打补丁步骤和注意事项

当我们使用 npm 上的第三方依赖包&#xff0c;如果发现 bug 时&#xff0c;怎么办呢&#xff1f; 想想我们在使用第三方依赖包时如果遇到了bug&#xff0c;通常解决的方式都是绕过这个问题&#xff0c;使用其他方式解决&#xff0c;较为麻烦。或者给作者提个issue&#xff0c;然…

dev board sig技术文章:轻量系统适配ARM架构芯片平台

摘要&#xff1a;本文简单介绍OpenHarmony轻量系统移植&#xff0c;会分多篇 适合群体&#xff1a;想自己动手移植OpenHarmony轻量系统的朋友 开始尝试讲解一下系统的移植&#xff0c;主要是轻量系统&#xff0c;也可能会顺便讲下L1移植。 1.1移植类型 OpenHarmony轻量系统的…

腾讯云服务器收费价格表(腾讯云服务器租用价格表)

作为国内领先的云计算服务提供商&#xff0c;腾讯云凭借其稳定、安全、高效的特点&#xff0c;备受用户青睐。本文将详细介绍腾讯云服务器的收费价格表及使用场景&#xff0c;帮助大家更好地了解并选择合适的云服务器方案。 一、轻量应用服务器 轻量应用服务器是一款开箱即用的…