多层感知机(神经网络)

news2024/9/20 18:43:41

目录

  • 一、感知机(逻辑回归、二分类)定义:
  • 二、感知机不能解决XOR问题:
  • 三、多层感知机定义:
  • 四、训练过程:
    • 1.参数维度:
    • 2.常用激活函数:
      • 2.1Sigmoid激活函数:
      • 2.2Tanh激活函数:
      • 2.3ReLU激活函数:
    • 3.训练过程举例:******
  • 五、底层代码实现:
  • 六、Pytorch版代码:

一、感知机(逻辑回归、二分类)定义:

感知机其实就是一个逻辑回归模型,解决的是二分类问题。
逻辑回归模型其实就是加入了激活函数后的线性回归模型,加入激活函数的作用是使得输出层单一神经元的单一输出值限制在0和1之间,更适合于二分类问题。
在这里插入图片描述
感知机的训练过程同线性回归,只不过在线性回归的基础上输出之前加入了激活函数进行映射。

二、感知机不能解决XOR问题:

由于逻辑回归模型只能通过一条直线将样本数据划分为两个分类,因此对于下面的样本,无论如何训练模型,得到的决策边界都不能将样本正确的划分。
在这里插入图片描述
因此对于上述问题,应该如何解决?

答案是将多个逻辑回归模型堆叠多层,就能很好的解决上述问题,这就是多层感知机的由来。
在这里插入图片描述
其中黄色的逻辑回归模型、蓝色的逻辑回归模型将样本分别分为两类。最后通过灰色的逻辑回归模型使用蓝色黄色的输出特征作为输入将样本最终分为两类。

三、多层感知机定义:

多层感知机(神经网络)是逻辑回归和Softmax回归的推广,将逻辑回归和Softmax回归堆叠来解决原来单一模型不能解决的问题。其中隐藏层h1–h5为逻辑回归模型,用于根据输入特征分别解决一个二分类问题,输出层o1–o3组成一个Softmax回归模型,根据隐藏层输出的特征进行三分类问题的预测。

在这里插入图片描述

四、训练过程:

1.参数维度:

在这里插入图片描述

  • 输入层维度固定,由数据决定。
  • 隐藏层神经元个数是个超参数,因此隐藏层参数矩阵W、b的行数固定,由输入层维度决定,但是列数不固定,由神经元个数决定。
  • 输出层参数矩阵W、b的列数固定,由分类数目决定,但是行数不固定,由隐藏层神经元个数决定。

对于多隐藏层情况,每个隐藏层都有各自的W、b参数,其中隐藏层层数也是一个超参数。
在这里插入图片描述
注意每一层都是一个全连接层。全连接层概念

2.常用激活函数:

2.1Sigmoid激活函数:

在这里插入图片描述

2.2Tanh激活函数:

在这里插入图片描述

2.3ReLU激活函数:

在这里插入图片描述

3.训练过程举例:******

以十分类模型的一次训练过程为例,其中隐藏层一层,隐藏层神经元个数为256:
1.获取一个batch,里面包含batch_size张图片。
2.将batch_size张图片展成一维(例如24×24的图片展成784),获得输入维度为:batch_size×784×1(图片数×特征维度[784×1])。
3.隐藏层参数W维度计算为784×256,参数b维度计算为1×256。
4.每张图片的所有特征分别输入隐藏层的各个神经元hi及其激活函数计算预测值yi,一张图片的输出维度为256×1,隐藏层对整个batch的输出维度为batch_size×256×1,作为输出层输入(隐藏层相当于提取特征)。
5.输出层参数W维度计算为256×10,参数b维度计算为1×10。
6.将隐藏层输出特征矩阵作为隐藏层输入,输出层是一个softmax回归模型。
7.接下来的操作同softmax回归,每个1×256×1的特征分别作为输入计算预测值,输出维度1×10的预测结果。
8.整个batch中的输出组合成维度batch_size×10。
9.使用softmax回归将输出映射成概率,维度为batch_size×10,并且每行概率之和为1。
10.使用交叉熵损失函数计算batch中所有图片的概率损失,并取均值。
11.计算各个参数wmn、bn关于损失函数的梯度。
12.反向传播算法修改参数值。
13.输入下一个batch进行训练。

五、底层代码实现:

import torch
from torch import nn
from d2l import torch as d2l
# 1.获取数据,封装成一个dataloader
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

num_inputs, num_outputs, num_hiddens = 784, 10, 256# 实现的多层感知机,其中隐藏层数为1,隐藏层中神经元个数为256
# 2.初始化参数值
# 隐藏层
W1 = nn.Parameter(
    torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)# num_inputs×num_hiddens
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))# 1×num_hiddens
# 输出层
W2 = nn.Parameter(
    torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)# num_hiddens×num_outputs
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))# 1×num_outputs

params = [W1, b1, W2, b2]

# 3.实现激活函数
def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)

# 4.损失函数
loss = nn.CrossEntropyLoss()

# 5.实现模型
def net(X):
    X = X.reshape((-1, num_inputs))#将输入X拉成二维矩阵,即batch_size×num_inputs(这里把特征拉成一维)
    H = relu(X @ W1 + b1)# 隐藏层
    return (H @ W2 + b2)# 输出层

# 6.训练过程
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

六、Pytorch版代码:

import torch
from torch import nn
from d2l import torch as d2l

# 1.网络架构
net = nn.Sequential(nn.Flatten(),# 将输入数据展平
                    nn.Linear(784, 256),# 隐藏层为全连接层
                    nn.ReLU(),# 隐藏层输出需经过激活函数
                    nn.Linear(256, 10)# 输出层也是全连接层
                    )
# 2.初始化参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

# 3.训练过程
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1947525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

排序XXXXXXXXX

信息学奥赛|常见排序算法总结(C+) - 腾讯云开发者社区-腾讯云 (tencent.com) https://cloud.tencent.com/developer/news/975232 常用序号层级排序 一、序号 序号Sequence Number,有顺序的号码,如数字序号…

数据结构: 链表回文结构/分割链表题解

目录 1.链表的回文结构 分析 代码 2.链表分割 ​编辑分析 代码 1.链表的回文结构 分析 这道题的难点是空间复杂度为O(1) 结合逆置链表找到链表的中间节点就可以解决了。 先找到链表的中间节点,再对中间节点的下一个节点进行逆置&…

代码随想录打卡第三十五天

代码随想录–动态规划部分 day 35 动态规划第三天 文章目录 代码随想录--动态规划部分一、卡码网46--携带研究材料二、力扣416--分割等和子集 一、卡码网46–携带研究材料 代码随想录题目链接:代码随想录 小明是一位科学家,他需要参加一场重要的国际科…

Leetcode—297. 二叉树的序列化与反序列化【困难】

2024每日刷题(148) Leetcode—297. 二叉树的序列化与反序列化 实现代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode(int x) : val(x), left(NULL), right(…

学习记录——day18 数据结构 树

树的存储 1、顺序存储 对于普通的二叉树,不适合存储普通的二叉树顶序存储,一般用于存储完全二叉树而言,如果使用顺序存储,会浪费大量的存储空间,因为需要给没有节点的位置留出空间,以便于后期的插入。 所以…

Springboot循环依赖的解决方式

Springboot循环依赖的解决方式 起因原因解决方案配置文件解决使用工具类获取bean还有一种我设想的方案 起因 今天重构代码时,发现之前的代码结构完全混乱,没有按照MVC分层思想去编写,很多业务逻辑写在了controller中,导致引用的很…

WebStorm中在Terminal终端运行脚本时报错无法加载文件进行数字签名。无法在当前系统上运行该脚本。有关运行脚本和设置执行策略的详细信息,请参阅

错误再现 我们今天要 在webstorm用终端运行脚本 目的是下一个openAPI的 前端请求代码生成的模块 我们首先从github上查看官方文档 我们根据文档修改 放到webstorm终端里执行 报错 openapi : 无法加载文件 C:\Users\ZDY\Desktop\多多oj\dduoj\node_modules\.bin\openapi.p…

LabVIEW多种测试仪器集成控制系统

在现代工业生产与科研领域,对测试设备的需求日益增长。传统的手动操作测试不仅效率低下,而且易出错。本项目通过集成控制系统,实现了自动化控制,降低操作复杂度和错误率,提高生产和研究效率。 系统组成与硬件选择 系…

逆向软件更新 x64dbg 加入 windows api 函数设断点插件

百度网盘链接:https://pan.baidu.com/s/1VaGP0rN8uTf8j_SzBgaEPg?pwd6666

Docker容器限制内存与CPU使用

文章目录 Docker 容器限制内存与 CPU 使用内存限额内存限制命令举例使用 `nginx` 镜像学习内存分配只指定 `-m` 参数的情况CPU 限制命令举例验证资源使用Docker 容器限制内存与 CPU 使用 在生产环境中,为了保证服务器不因某一个软件导致服务器资源耗尽,我们会限制软件的资源…

用uniapp 及socket.io做一个简单聊天app 2

在这里只有群聊,二个好友聊天,可以认为是建了一个二人的群聊。 const express require(express); const http require(http); const socketIo require(socket.io); const cors require(cors); // 引入 cors 中间件const app express(); const serv…

学术研讨 | 区块链治理与应用创新研讨会顺利召开

学术研讨 近日,国家区块链技术创新中心组织,长安链开源社区支持的“区块链治理与应用创新研讨会”顺利召开,会议围绕区块链治理全球发展现状、研究基础、发展趋势以及区块链行业应用创新展开研讨。北京大学陈钟教授做了“区块链治理与应用创…

我们的人生,向阳而生,去更远的地方,见更亮的光

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/140683410 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

qt SQLite学习记录

1. 查看qt中数据库的驱动的类型的支持 QStringList drivers QSqlDatabase::drivers();//获取qt中所支持的数据库驱动类型foreach(QString driver,drivers){qDebug()<<driver;}2. Qt SQL 模块包含的主要类的功能介绍 Qt SQL 模块包含了一些主要的类&#xff0c;用于在 …

自动驾驶系统开发与调试:车路云一体化无人驾驶挑战赛参赛体验

点击蓝字 关注我们 在过去的几年里&#xff0c;自动驾驶技术在全球范围内吸引了大量关注。其潜力不仅在于提升行车安全&#xff0c;而且还可以改变我们的出行方式和城市规划&#xff0c;提高交通运输效率。国际汽车工程师学会&#xff08;SAE&#xff09;根据不同自动驾驶程度&…

Linux中,MySQL数据库管理

使用MySQL数据库 查看数据库结构 MySQL是一套数据库管理系统&#xff0c;在每台MySQL服务器中&#xff0c;均支持运行多个数据库&#xff0c;每个数据库相当于一个容器&#xff0c;其中存放着许多表&#xff0c;如图2.1所示。 下面分别介绍查看数据库、表结构的相关操作语句。…

单片机原理及技术(四)—— C51语言程序设计基础(C51编程)

目录 一、C51语言中的数据类型与存储类型 1.1 数据类型 1.2 C51语言的扩展数据类型 1.3 数据存储类型 1.4 数据存储模式 二、C51语言的特殊寄存器及变量定义 2.1 特殊功能寄存器的C51语言定义 2.1.1 使用关键字定义sfr 2.1.2 使用头文件访问SFR 2.1.3 特殊功能寄存器…

《梦醒蝶飞:释放Excel函数与公式的力量》18.1 图表类型与设计

第18章&#xff1a;创建图表和数据可视化 18.1 图表类型与设计 Excel提供了多种图表类型&#xff0c;帮助用户以直观的方式展示数据。选择合适的图表类型和设计可以显著提高数据的可读性和理解度。以下将介绍常见的图表类型及其应用&#xff0c;并通过具体案例进行说明。 18.…

无人机制造工艺流程详解

一、需求分析 无人机制造的第一步是需求分析。这一阶段主要明确无人机的使用场景、功能要求、性能指标以及成本预算等。通过与客户或项目团队的深入沟通&#xff0c;确保对无人机的需求有全面而准确的理解。同时&#xff0c;也需要进行市场调研&#xff0c;了解同类型产品的特…

达梦数据库系列—31. 事务和锁

目录 事务 事务的状态 事务的四种隔离级别ACID 锁 锁粒度 锁等待与死锁 锁查看 锁处理 事务 数据库事务是指作为单个逻辑工作单元的一系列操作的集合。 事务的状态 NOT_START 未启动 ACTIVE 活动 LOCK_WAIT 锁等待 TRX4_PRE_COMMIT 预提交 事务ID(事务号)&…