【深度学习实验】前馈神经网络(五):自定义线性模型:前向传播、反向传播算法(封装参数)

news2024/12/23 0:32:18

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 线性模型Linear类

a. 构造函数__init__

b. __call__(self, x)方法

c. 前向传播forward

d. 反向传播backward

2. 模型训练

3. 代码整合


一、实验介绍

  • 实现线性模型(Linear类)
    • 实现前向传播forward
    • 实现反向传播backward

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch

1. 线性模型Linear类

a. 构造函数__init__

  def __init__(self, input_size, output_size):
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.inputs = None
        self.grads = {}
  • 成员变量:

    • params:用于保存模型的参数,包括权重矩阵 W 和偏置向量 b
    • inputs:保存输入数据的变量。
    • grads:保存参数的梯度的变量。

b. __call__(self, x)方法

    __call__(self, x)方法使得该类的实例可以像函数一样被调用。它调用了forward(x)方法,将输入的x传递给前向传播方法。

 def __call__(self, x):
        return self.forward(x)

c. 前向传播forward

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']

        return outputs

在前向传播中,输入数据经过线性变换操作得到输出:

  • 在构造函数中,使用 nn.Parameter 将随机初始化的权重矩阵 W 和偏置向量 b 包装成可训练的参数。
  • 在 forward 方法中,输入数据 inputs 与权重矩阵 W 相乘,然后加上偏置向量 b,得到输出值 outputs
  • forward 方法返回计算得到的输出值。

d. 反向传播backward

  def backward(self, grads=None):
        if grads == None:
            grads = torch.ones(self.params['W'].shape)
        self.grads['w'] = torch.matmul(self.inputs.T, grads)
        self.grads['b'] = torch.sum(grads, dim=0)
        return torch.matmul(grads, self.params['W'].T)

   backward(self, grads=None)方法执行线性变换的反向传播:

  • 它接受一个可选的参数grads,用于传递输出的梯度。
  • 如果没有提供grads,则默认为全1的张量,表示对输出的梯度都为1。
  • 在线性变换中,计算输入的梯度需要使用输出的梯度和当前输入值。这里使用了矩阵乘法和求和操作来计算参数的梯度和输入的梯度
  • 返回计算得到的输入梯度。

2. 模型训练

net = Linear(4, 2)
x = torch.tensor([1,1,1,1], dtype=torch.float32)
y = net(x)
z = net.backward()
print(z)
  • 创建了一个Linear的实例net;
  • 传入输入张量x进行前向传播;
  • 调用net.backward()进行反向传播,得到输入x的梯度
  • 将结果打印输出。
tensor([[-0.8962, -0.9053, -1.5650, -0.3181],
        [-0.8962, -0.9053, -1.5650, -0.3181],
        [-0.8962, -0.9053, -1.5650, -0.3181],
        [-0.8962, -0.9053, -1.5650, -0.3181]], grad_fn=<MmBackward>)

3. 代码整合

# 导入必要的工具包
import torch

class Linear:
    def __init__(self, input_size, output_size):
        self.params = {}
        self.params['W'] = nn.Parameter(torch.randn(input_size, output_size, requires_grad=True))
        self.params['b'] = nn.Parameter(torch.randn(1, output_size, requires_grad=True))
        self.inputs = None
        self.grads = {}

    def __call__(self, x):
        return self.forward(x)

    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']

        return outputs

    def backward(self, grads=None):
        if grads == None:
            grads = torch.ones(self.params['W'].shape)
        self.grads['w'] = torch.matmul(self.inputs.T, grads)
        self.grads['b'] = torch.sum(grads, dim=0)
        return torch.matmul(grads, self.params['W'].T)

net = Linear(4, 2)
x = torch.tensor([1,1,1,1], dtype=torch.float32)
y = net(x)
z = net.backward()
print(z)

注意:

        本实验仅实现了线性模型的前向传播和反向传播部分,缺少了模型的训练部分,欲知后事如何,请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1036598.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机网络】IP协议(下)

文章目录 1. 特殊的IP地址2. IP地址的数量限制3. 私有IP地址和公网IP地址私有IP为什么不能出现在公网上&#xff1f;解决方案——NAT技术的使用 4. 路由5. IP分片问题为什么要进行切片&#xff1f;如何做的分片和组装&#xff1f;16位标识3位标志13位片偏移例子 细节问题如何区…

一文带你玩转logo:含义、获取、使用以及2000多知名logo大图资源

大家好&#xff01;logo是我们非常熟悉的一种事物&#xff0c;但是我发现很多场合的logo使用并不规范、高效&#xff0c;所以今天六分成长来带着大家了解一下关于logo的方方面面。 一、什么是logo&#xff1f; logo不是某一些英文单词的缩写&#xff0c;是一个完整的单词&…

uniapp如何判断是哪个(微信/APP)平台

其实大家在开发uniapp项目的时候长长会遇到这样一个问题&#xff0c;就是针对某些小程序&#xff0c;没发去适配相关的功能&#xff0c;所以要针对不同的平台&#xff0c;进行不同的处理。 #ifdef &#xff1a; if defined 仅在某个平台编译 #ifndef &#xff1a; …

机器学习实验一:使用 Logistic 回归来预测患有疝病的马的存活问题

代码&#xff1a; import pandas as pd import numpy as np from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report import matplotlib.pyplot as plt def train(): # …

机器学习---神经元模型

1. 生物学的启示 神经元在结构上由细胞体、树突、轴突和突触四部分组成。 细胞体是神经元的主体&#xff0c;由细胞核、细胞质和细胞膜3部分组成。细胞体的外部是细胞膜&#xff0c;将 膜内外细胞液分开。由于细胞膜对细胞液中的不同离子具有不同的通透性&#xff0c;这使得膜…

XXE 漏洞及案例实战

文章目录 XXE 漏洞1. 基础概念1.1 XML基础概念1.2 XML与HTML的主要差异1.3 xml示例 2. 演示案例2.1 pikachu靶场XML2.1.1 文件读取2.1.2 内网探针或者攻击内网应用&#xff08;触发漏洞地址&#xff09;2.1.4 RCE2.1.5 引入外部实体DTD2.1.6 无回显读取文件 3. XXE 绕过3.1 dat…

【操作系统】线程、多线程

为什么要引入线程&#xff1f; 传统的进程只能串行的执行一系列程序&#xff0c;线程增加并发度。同一个进程分为多个线程。 线程是调度的基本单元&#xff0c;程序执行流的最小单位&#xff0c;基本的CPU执行单元。 进程是资源分配的基本单位。 线程的实现方式 用户级线程 代…

Unity入门教程(上)

七、运行游戏 再次保存我们的项目文件&#xff08;返回步骤四&#xff09;。保存完成后&#xff0c;让我们把游戏运行起来。 1&#xff0c;确认游戏视图标签页右上方的Maximize on Play图标处于按下状态&#xff0c;然后点击画面上方的播放按钮&#xff08;位于工具栏中间的播…

C++类模板学习

之前已经学习了函数模板&#xff0c;在这里&#xff0c; C函数模板Demo - win32 版_c编写的opc da demo_bcbobo21cn的博客-CSDN博客 下面学习类模板&#xff1b; VC6&#xff1b; 做一个星星类&#xff0c;Star&#xff1b; Star.h&#xff1b; #if !defined(AFX_STAR_H_…

(十二)VBA常用基础知识:worksheet的各种操作之sheet移动

当前sheet确认 把sheet1移动到sheet3前边 Sub Hello()10Worksheets("Sheet1").Move Before:Worksheets("Sheet3") End Sub3. 把sheet2移动到sheet1后边 Sub Hello()11Worksheets("Sheet2").Move after:Worksheets("Sheet1") End Sub…

MissionPlanner编译过程

环境 windows 10 mission planner 1.3.80 visual studio 2022 git 2.22.0 下载源码 (已配置git和ssh) 从github上克隆源码 git clone gitgithub.com:ArduPilot/MissionPlanner.git进入根目录 cd MissionPlanner在根目录下的ExtLibs文件下是链接的其它github源码&#xff0…

pymysql简介以及安装

视频版教程 Python操作Mysql数据库之pymysql模块技术 前面基础课程介绍了使用文件来保存数据&#xff0c;这种方式虽然简单、易用&#xff0c;但只适用于保存一些格式简单、数据量不太大的数据。对于数据量巨大且具有复杂关系的数据&#xff0c;当然还是推荐使用数据库进行保存…

79、SpringBoot 整合 R2DBC --- R2DBC 就是 JDBC 的 反应式版本, R2DBC 是 JDBC 的升级版。

★ 何谓R2DBC R2DBC 就是 JDBC 的 反应式版本&#xff0c; R2DBC 是 JDBC 的升级版。 R2DBC 是 Reactive Relational Database Connectivity (关系型数据库的响应式连接) 的缩写 反应式的就是类似于消息发布者和订阅者&#xff0c;有消息就进行推送。R2DBC中DAO接口中方法的…

Rust vs C++ 深度比较

Rust由于其强大的安全性受到大量关注&#xff0c;被认为C在系统编程领域最强大的挑战者。本文从语言、框架等方面比较了两者的优缺点。原文: Rust vs C: An in-depth language comparison Rust和C的比较是开发人员最近的热门话题&#xff0c;两者之间有许多相似之处&#xff0c…

Linux复习-安装与熟悉环境(一)

这里写目录标题 虚拟机ubuntu系统配置镜像Linux命令vi编辑器3个模式光标命令vi模式切换命令vi拷贝与粘贴命令vi保存和退出命令vi的查找命令vi替换命令 末行模式复制、粘贴、剪切gcc编译器 虚拟机 VMware16 官网下载&#xff1a;vmware官网 网盘下载&#xff1a; 链接&#xff…

共享文件夹设置密码怎么做?3招轻松为文件上锁!

“我们小组里建了一个共享文件夹&#xff0c;为了安全起见&#xff0c;想给文件夹设置一个密码&#xff0c;但是不知道应该怎么操作&#xff0c;有没有大佬可以教教我呀&#xff01;” 在我们的工作中&#xff0c;经常都会用到共享文件&#xff0c;这样可以让我们的工作方便快捷…

Jmeter接口测试

前言&#xff1a; 本文主要针对http接口进行测试&#xff0c;使用Jmeter工具实现。 Jmter工具设计之初是用于做性能测试的&#xff0c;它在实现对各种接口的调用方面已经做的比较成熟&#xff0c;因此&#xff0c;本次直接使用Jmeter工具来完成对Http接口的测试。 1.介绍什么是…

负载均衡技术全景:理论、实践与案例研究

在当今的互联网时代&#xff0c;随着用户数量的增长和业务规模的扩大&#xff0c;单一的服务器已经无法满足高并发、大流量的需求。为了解决这个问题&#xff0c;负载均衡技术应运而生。负载均衡可以将大量的网络请求分发到多个服务器上进行处理&#xff0c;从而提高系统的处理…

Qt-双链表的插入及排序

输入一个二维链表将其排序后转化成一维链表 要求&#xff1a;链表自定义不得使用模板库 链接&#xff1a;私信

Spring Cloud Gateway快速入门(一)——网关简介

文章目录 前言一、什么是网关1.1 gateway的特点1.2 为什么要使用gateway 二、使用 Nginx 实现网关服务什么是网关服务&#xff1f;为什么选择 Nginx 作为网关服务&#xff1f;如何使用 Nginx 实现网关服务&#xff1f;1. 安装 Nginx2. 配置 Nginx3. 启动 Nginx4. 测试网关服务 …