动手学深度学习——多层感知机

news2025/1/12 1:50:17

1. 感知机

感知机本质上是一个二分类问题。给定输入x、权重w、偏置b,感知机输出:

以猫和狗的分类问题为例,它本质上就是找到下面这条黑色的分割线,使得所有的猫和狗都能被正确的分类。

与线性回归和softmax的不同点:

  • vs 线性回归:输出的都是一个数,但线性回归输出的是实数,而感知机输出的是离散的分类。
  • vs softmax: softmax是一个多分类(如果有n个分类,softmax就会输出n个元素),而感知机只输出一个元素。

感知机存在的问题: 它只能产生线性分割面,对于XOR(异或)函数,无法拟合(一条线不论怎么分割,都无法将绿色和红色分类正确)。

2. 多层感知机(MLP)

对于上面单层感知机的问题,一个改进思想是:一层函数如果做不了,就用多层函数来做,而多层就带来了网络,用不同层解决不同的问题,多层配合来解决更复杂的问题。

可以使用蓝线对所有数据进行x轴方向的正负分类,再使用黄线对所有数据进行y轴方向的正负分类,最后再将两次分类结果进行xor运算就能得到结果。

多层感知机使用隐藏层和激活函数来得到非线性模型。

在softmax基础上多了隐藏层。可选超参:

  • 隐藏层数
  • 每个隐藏层的宽度,通常选择2的若干次冥作为层的宽度

这两个参数的选择取决于输入和输出的复杂度

对复杂的输入,输入维度一般比较高,输出一般会比较少,有两种处理办法:

  1. 做单隐藏层,把模型做平,层的大小设大一点
  2. 做多隐藏层,把模型做深,层的大小可以设小一点,每层的维度逐步减少(如果每层维度都高,则会导致模型太大)

复杂输入到简单输出本质上是一个信息压缩的过程,多层逐步压缩能避免一次压缩太大导致信息损失太严重,例如:128->64->32->16->8
也可以先expand,从128->256->64->32->16->8

3. 激活函数

作用:在神经网络中引入非线性,可以理解为一个开关,当输入信号超过一定阀值时,神经元会被激活并产生输出,而未超过阀值时神经元将会被抑制。

在没有激活函数的情况下,神经网络只能表示线性映射,无法处理复杂的非线性关系。激活函数的作用就是线性结果映射到一个非线性的输出,以帮助神经网络更好的适应输入数据,提高非线性拟合能力。

举例:一个邮件过滤模型中的神经元,负责对输入邮件的特征(长度、关键词等)进行加权求和,但这个结果只是一个连续的数值我们交

激活函数不能是线性函数,否则会变成单层感知机,依然会存在线性分割面无法处理XOR的问题。

激活函数主要作用于隐藏层。

激活函数的几种选择:

  1. sigmoid: 对于任意输入x,都能投影到0~1区间内。

  2. tanh(x): 将输入投影到[-1,1]区间内

  1. ReLU: 就是一个Max函数(常用),特点是计算很快,相比前面基于指数运算的sigmoid和tanh函数都快很多(一次指数运算要100个时钟周期)

对ReLU函数求导,小于等于0时都是0,大于0时都是1,最终结果就是一个二分类。

4. 代码实现

4.1 初始化参数

我们将实现一个具有单隐藏层的多层感知机, 这个隐藏层包含128个隐藏单元。

对于每一层我们都要记录一个权重矩阵和一个偏置向量,并指定requires_grad=True来记录参数梯度。

import torch
from torch import nn
from d2l import torch as d2l

num_inputs, num_outputs, num_hiddens = 784, 10, 128

W1 = nn.Parameter(torch.randn(
    num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(
    num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

4.2 加载数据集

这里继续使用Fashion-MNIST图像分类数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

4.3 激活函数

Relu函数的实现比较简单,就是一个max函数的调用, 它将输入的负值部分截断为0,保留正值部分不变。

def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)
  • torch.zeros_like(X): 创建了一个与X具有相同形状的全零张量a。
  • torch.max(X, a): 对于输入X中的每个元素,如果它是正值,则该元素保留不变;如果它是负值,则将其替换为0。

4.4 模型

def net(X):
    X = X.reshape((-1, num_inputs))    
    H = relu(X@W1 + b1)  # 隐藏层,这里“@”代表矩阵乘法
    return (H@W2 + b2)   # 输出层
  1. 使用reshape将输入的二维图像转换为一个长度为num_inputs=784的向量;
  2. 用ReLu函数对隐藏层的线性输出进行激活,得到输出张量H;
  3. 最后,由张量H和权重矩阵W2进行矩阵乘法操作,将偏置向量b2加到结果上,得到预测输出结果。

4.5 损失函数

这里直接使用pytorch中内置的交叉熵损失函数。

loss = nn.CrossEntropyLoss(reduction='none')

4.6 训练

多层感知机的训练过程与softmax的训练过程完全相同,可以直接调用之前定义过的train_ch3函数。

# 将迭代周期数设置为10,并将学习率设置为0.1.
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练过程中的模型损失和精度的收敛变化:

epoch: 1, loss: 1.1021366075515746, test_acc: 0.7544
epoch: 2, loss: 0.6142196039199829, test_acc: 0.8004
epoch: 3, loss: 0.5257990721384684, test_acc: 0.8061
epoch: 4, loss: 0.4842481053034465, test_acc: 0.7988
epoch: 5, loss: 0.4575055497487386, test_acc: 0.8266
epoch: 6, loss: 0.4389862974802653, test_acc: 0.8382
epoch: 7, loss: 0.42252545185089113, test_acc: 0.8443
epoch: 8, loss: 0.40933472124735515, test_acc: 0.8458
epoch: 9, loss: 0.3975078603744507, test_acc: 0.8467
epoch: 10, loss: 0.38488629398345947, test_acc: 0.8527

基于之前softmax模型上定义的预测函数,在测试数据集上使用这个模型做验证:

predict_ch3(net, test_iter)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1661281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务丢在tomcat中启动war包,需要在tomcat中配置Java环境吗?

一般来说,部署在 Tomcat 上的 WAR 包启动时不需要在 Tomcat 中单独配置 Java 环境,因为 Tomcat 启动本身就需要依赖 Java 环境。以下是确保 Tomcat 正常运行与部署 WAR 包的基本步骤: 安装 Java 环境: 首先,确保你的系…

Web Component fancy-components

css-doodle 组件库 fancy-components 组件库使用 yarn add fancy-components使用: import { FcBubbles } from fancy-components new FcBubbles() //要用哪个就new哪个 new 这里可能会报错eslink,eslintrc.js中处理报错 module.exports {rules: {no-new: off} …

Python运维之定时任务模块APScheduler

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 目录 定时任务模块APScheduler 一、安装及基本概念 1.1、APScheduler的安装 1.2、涉及概念 1.3、APScheduler的工作流程​编辑 二、配置调度器 …

luceda ipkiss教程 68:通过代码模板提高线路设计效率

在用ipkiss设计器件或者线路时,经常需要输入: from ipkiss3 import all as i3那么有什么办法可以快速输入这段代码呢?这里就可以利用Pycharm的 live template功能,只需要将文件:ipkiss.xml (luceda ipkiss教程 68&…

P9420 [蓝桥杯 2023 国 B] 子 2023 / 双子数

蓝桥杯2023国B A、B题 A题 分析 dp问题 根据子序列:2,20,202,2023分为4个状态; 当前数字为2时,处于dp[0],或者和dp[1]结合成dp[2]; 当前数字为0时,和dp[0]结合成dp[…

数据结构学习/复习12

一、排序概念与应用 二、插入排序 三、希尔排序 当间隔数为1时则为插入排序 1.一组一组排 2.多组并排 3.间隔数变化直至为1 四、性能测速代码

XSS-Labs 靶场通过解析(下)

前言 XSS-Labs靶场是一个专门用于学习和练习跨站脚本攻击(XSS)技术的在线平台。它提供了一系列的实验场景和演示,帮助安全研究人员、开发人员和安全爱好者深入了解XSS攻击的原理和防御方法。 XSS-Labs靶场的主要特点和功能包括:…

判断字符是否唯一——力扣

面试题 01.01. 判定字符是否唯一 已解答 简单 相关标签 相关企业 提示 实现一个算法,确定一个字符串 s 的所有字符是否全都不同。 示例 1: 输入: s "leetcode" 输出: false 示例 2: 输入: s "abc" 输出: true…

若依生成树表和下拉框选择树表结构(在其他页面使用该下拉框输入)

1.数据库表设计 生成树结构的主要列是id列和parent_id列,后者指向他的父级 2.来到前端代码生成器页面 导入你刚刚写出该格式的数据库表 3.点击编辑,来到字段 祖籍列表是为了好找到直接父类,不属于代码生成器方法,需要后台编…

LeetCode例题讲解:876.链表的中间结点

给你单链表的头结点 head ,请你找出并返回链表的中间结点。 如果有两个中间结点,则返回第二个中间结点。 示例 1: 输入:head [1,2,3,4,5] 输出:[3,4,5] 解释:链表只有一个中间结点,值为 3 。…

一篇详解Git版本控制工具

华子目录 版本控制集中化版本控制分布式版本控制 Git简史Git工作机制Git和代码托管中心局域网互联网 Git安装基础配置git的--local,--global,--system的区别 创建仓库方式1git init方式2git clone git网址 工作区,暂存区,本地仓库…

2023盘古石晋级赛 移动终端取证 WP

9. 根据容恨寒的安卓手机分析,MAC的开机密码是[答案:asdcz] 到这里火眼就寄了,盘古石 启动! 10. 根据容恨寒的安卓手机分析,苹果手机的备份密码前4位是[答案:1234] 11. 根据魏文茵苹果手机分析&#xff0c…

基于JAVAEE的停车场管理系统(论文 + 源码)

【免费】基于JAVAEE的停车场管理系统.zip资源-CSDN文库https://download.csdn.net/download/JW_559/89292324 基于JAVAEE的停车场管理系统 摘 要 如今,我国现代化发展迅速,人口比例急剧上升,在一些大型的商场,显得就格外拥挤&…

算法设计与分析 动态规划/回溯

1.最大子段和 int a[N]; int maxn(int n) {int tempa[0];int ans0;ansmax(temp,ans);for(int i1;i<n;i){if(temp>0){tempa[i];}else tempa[i];ansmax(temp,ans);}return ans; } int main() {int n,ans0;cin>>n;for(int i0;i<n;i) cin>>a[i];ansmaxn(n);co…

Spring添加注解读取和存储对象

5大注解 Controller 控制器 Service 服务 Repository 仓库 Componet 组件 Configuration 配置 五大类注解的使用 //他们都是放在同一个目录下&#xff0c;不同的类中 只不过这里粘贴到一起//控制器 Controller public class UserController {public void SayHello(){System.ou…

在新页面中跳转到指定 div容器位置

要在打开新的页面时跳转到指定 div&#xff0c;我们需要结合 HTML、JavaScript 和后端技术来实现。以下是两种常见的方法&#xff1a; 使用 URL 参数传递目标 div 信息 HTML (新页面): 在新页面的链接中&#xff0c;添加参数来指示目标 div 的 id&#xff0c;例如&#xff1a;…

Android GPU渲染SurfaceFlinger合成RenderThread的dequeueBuffer/queueBuffer与fence机制(2)

Android GPU渲染SurfaceFlinger合成RenderThread的dequeueBuffer/queueBuffer与fence机制&#xff08;2&#xff09; 计算fps帧率 用 adb shell dumpsys SurfaceFlinger --list 查询当前的SurfaceView&#xff0c;然后有好多行&#xff0c;再把要查询的行内容完整的传给 ad…

im8mm 网络卡死 Rx packets:1037578 errors:66 dropped:0 overruns:66 frame:0

1&#xff1a;网络接收数据包异常 2&#xff1a;问题复现 问题在进行网络数据包同吞吐量测试的时候出现的。同时发现&#xff0c;在使用iperf2测试时&#xff0c;是不会出现网络中断卡死的情况&#xff0c;使用 iperf3时才会出现此问题 指令(下面的指令运行在PC2上面&#xff…

Linux 安装JDK和Idea

安装JDK 下载安装包 下载地址&#xff1a; Java Downloads | Oracle (1) 使用xshell 上传JDK到虚拟机 (2) 移动JDK 包到/opt/environment cd ~ cd /opt sudo mkdir environment # 在 /opt下创建一个environment文件夹 ls# 复制JDK包dao /opt/environment下 cd 下载 ls jd…

聊天室项目思路

发起群聊&#xff1a; 从好友表选取人发送到服务器&#xff0c;服务器随机生成不重复的群号&#xff0c;存储在数据库&#xff0c;同时建立中间表&#xff0c;处理用户与群聊的关系 申请入群&#xff1a; 输入群号&#xff0c;发消息给服务器&#xff0c;服务器查询是否存在…