机器学习复习(3)——分类神经网络与drop out

news2024/9/22 9:45:13

完整的神经网络

以分类任务为例,神经网络一般包括backbone和head(计算机视觉领域)

下面的BasicBlock不是一个标准的backbone,标准的应该是复杂的CNNs构成的

Classfier是一个标准的head,其中output_dim表示分类类别,一般写作num_classes

import torch  # 导入 torch 库
import torch.nn as nn  # 导入 torch 的神经网络模块
import torch.nn.functional as F  # 导入 torch 的函数式接口

# 定义一个基础的神经网络模块
class BasicBlock(nn.Module):  # 继承自 torch 的 Module 类
    def __init__(self, input_dim, output_dim):
        super(BasicBlock, self).__init__()  # 初始化父类

        # 构建一个序列模块,包含一个线性层和一个 ReLU 激活函数
        self.block = nn.Sequential(
# 线性层,输入维度为 input_dim,输出维度为 output_dim
            nn.Linear(input_dim, output_dim),  
            nn.ReLU(),  # ReLU 激活函数
        )

    def forward(self, x):
        x = self.block(x)  # 将输入数据 x 通过定义的序列模块
        return x  # 返回模块的输出


# 定义一个分类器神经网络
class Classifier(nn.Module):  # 继承自 torch 的 Module 类
    def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):
        super(Classifier, self).__init__()  # 初始化父类

        # 构建一个序列模块,包含若干个 BasicBlock 和一个线性输出层
        self.fc = nn.Sequential(
# 第一个 BasicBlock,将输入维度转换为隐藏层维度
            BasicBlock(input_dim, hidden_dim),  
# 根据 hidden_layers 数量添加多个 BasicBlock
            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],  
# 线性输出层,将隐藏层维度转换为输出维度
            nn.Linear(hidden_dim, output_dim)  
        )

    def forward(self, x):
        x = self.fc(x)  # 将输入数据 x 通过定义的序列模块
        return x  # 返回模块的输出

对 *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)]的一个补充解释,“*”代表解压列表,例如A=[a,b,c],那么f(*A)=f(a,b,c)

在这里的具体意义是“便于控制隐藏层数量”,而其中的"_"代表不希望在循环中使用变量,这是一种通用的惯例,表明循环的目的不是对每个元素进行操作,而是为了重复某个操作特定次数。如果hidden_layers=3,这里的等价含义就是BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),BasicBlock(hidden_dim, hidden_dim),——连续出现三次

dropout

Dropout层在神经网络层当中是用来干什么的呢?它是一种可以用于减少神经网络过拟合的结构。

如上图我们定义的网络,一共有四个输入x_i,一个输出y。Dropout则是在每一个batch的训练当中随机减掉一些神经元,而作为编程者,我们可以设定每一层dropout(将神经元去除的的多少)的概率,在设定之后,就可以得到第一个batch进行训练的结果:  

从上图我们可以看到一些神经元之间断开了连接,因此它们被dropout了!dropout顾名思义就是被拿掉的意思,正因为我们在神经网络当中拿掉了一些神经元,所以才叫做dropout层。
在进行第一个batch的训练时,有以下步骤:

  • 设定每一个神经网络层进行dropout的概率
  • 根据相应的概率拿掉一部分的神经元,然后开始训练,更新没有被拿掉神经元以及权重的参数,将其保留
  • 参数全部更新之后,又重新根据相应的概率拿掉一部分神经元,然后开始训练,如果新用于训练的神经元已经在第一次当中训练过,那么我们继续更新它的参数。而第二次被剪掉的神经元,同时第一次已经更新过参数的,我们保留它的权重,不做修改,直到第n次batch进行dropout时没有将其删除。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1421523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

三步实现 Sentinel-Nacos 持久化

一、背景 版本:【Sentinel-1.8.6】 模式:【Push 模式】 参照官网介绍:生产环境下使用Sentinel ,规则管理及推送模式有以下3种模式: 比较之后,目前微服务都使用了各种各样的配置中心,故采用Pus…

手机屏幕生产厂污废水处理需要哪些工艺设备

随着手机行业的快速发展,手机屏幕生产厂的规模也越来越大,但同时也带来了大量的污废水排放问题。为了保护环境和人类的健康,手机屏幕生产厂需要采取适当的工艺设备来处理污废水。本文将介绍手机屏幕生产厂污废水处理所需的工艺设备。 首先&am…

【环境配置】安装了pytorch但是报错torch.cuda.is_availabel()=Flase

解决思路:import torch正常,说明torch包安装正常,但是不能和gpu正常互动,猜测还是pytroch和cuda的配合问题 1.查看torch包所需的cuda版本 我的torch是2.0.1,在现在是比较新的包,需要12以上的cuda支持&…

【算法与数据结构】198、213、337LeetCode打家劫舍I, II, III

文章目录 一、198、打家劫舍二、213、打家劫舍 II三、337、打家劫舍III三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、198、打家劫舍 思路分析:打家劫舍是动态规划的的经典题目。本题的难点在于递归公式…

Android开发之UI控件

TextView 实现阴影效果的textview android:shadowColor"#ffff0000" 设置阴影颜色为红色android:shadowRadius"3" 设置阴影的模糊程度为3android:shadowDx"10" 设置阴影在水平方向的偏移android:shadowDy"10" 设置阴影在竖直方向的偏…

iOS17使用safari调试wkwebview

isInspectable配置 之前开发wkwebview的页面的时候一直使用safari调试,毕竟jssdk交互还是要用这个比较方便,虽说用一个脚本插件没问题。不过还是不太方便。 但是这个功能突然到了iOS17之后发现不能用了,还以为又是苹果搞得bug,每…

Unity 状态模式(实例详解)

文章目录 简介示例1:基础角色状态切换示例2:添加更多角色状态示例3:战斗状态示例4:动画同步状态示例5:状态机管理器示例6:状态间转换的条件触发示例7:多态行为与上下文类 简介 Unity 中的状态模…

一个产品是怎么诞生的

一个产品的诞生,首先从假设需求开始,或者从玩耍的创客开始。 假设需求往往风险很大,你如果没有结合实际的生活经验或者是玩耍经验,凭空在脑子里想到一个东西,要把它创造出来,这样的东西极有可能会遭遇商业上…

ASP.NET Core 使用 SignalR 的简单示例

写在前面 ASP.NET SignalR 是一个开源代码库,简化了Web实时通讯方案,可以实时地通过服务端将信息同步推送到各个客户端,可应用于 需要从服务器进行高频更新的应用:包括游戏、社交网络、投票、拍卖、地图和GPS应用; 仪…

Servlet过滤器个监听器

过滤器和监听器 过滤器 什么是过滤器 当浏览器向服务器发送请求的时候,过滤器可以将请求拦截下来,完成一些特殊的功能,比如:编码设置、权限校验、日志记录等。 过滤器执行流程 Filter实例 package com.by.servlet;import jav…

2024年航海制造工程与海洋工程国际会议(ICNMEME2024)

一、【会议简介】 2024年航海制造工程与海洋工程国际会议(ICNMEME2024)旨在将研究人员、工程师、科学家和行业专业人士聚集在一个开放论坛上,展示他们在导航制造工程与海洋工程领域的激励研究和知识转移理念。然而,我们也认识到,工程师的未来…

【操作系统·考研】虚拟内存管理

1.概述 传统存储管理方式具有两个特征 一次性:作业必须一次性全部装入内存后,才能开始运行。驻留性:作业被装入内存后,就一直驻留在内存中,在其运行期间作业的任何部分都无法被换出。 显然,这两个特性非…

【深度学习】数据归一化/标准化 Normalization/Standardization

目录 一、实际问题 二、归一化 Normalization 三、归一化的类型 1. Min-max normalization (Rescaling) 2. Mean normalization 3.Z-score normalization (Standardization) 4.非线性归一化 4-1 对数归一化 4-2 反正切函数归一化 4-3 小数定标标准化(Demi…

山海鲸智慧教育方案:教育数据的未来

作为山海鲸可视化软件的开发者,我们深知数据可视化在教育领域的重要价值。山海鲸智慧教育解决方案正是在这样的背景下应运而生,致力于为教育行业提供高效、直观的数据可视化解决方案。 随着教育信息化的深入推进,教育数据呈爆炸式增长。如何…

嵌入式学习 Day14

一. 三个函数 1.strncpy char *strncpy(char *dest, const char *src, size_t n) // 比正常拷贝多了一个n { n < strlen(src) // 只拷贝前n个字符&#xff0c;最终dest中不会有\0 n strlen(src) // 正常拷贝 n > strlen(src) …

【Golang】ModbusRTU协议CRC16校验算法

CRC校验码是通过在数据后面附加一个短的校验序列来生成的&#xff0c;用于检测数据在传输过程中是否发生错误。CRC16是一种特定的CRC校验算法&#xff0c;它生成一个16位的校验码。 下面是使用Go语言实现CRC16校验算法的代码&#xff1a; package main import ("encoding…

【01】Linux 基本操作指令

带⭐的为重要指令 &#x1f308; 01、ls 展示当前目录下所有文件&#x1f308; 02、pwd 显示用户当前所在路径&#x1f308; 03、cd 进入指定目录&#x1f308; 04、touch 新建文件&#x1f308; 05、tree 以树形结构展示所有文件⭐ 06、mkdir 新建目录⭐ 07、rmdir 删除目录⭐…

某赛通电子文档安全管理系统 PolicyAjax SQL注入漏洞复现

0x01 产品简介 某赛通电子文档安全管理系统(简称:CDG)是一款电子文档安全加密软件,该系统利用驱动层透明加密技术,通过对电子文档的加密保护,防止内部员工泄密和外部人员非法窃取企业核心重要数据资产,对电子文档进行全生命周期防护,系统具有透明加密、主动加密、智能…

Linux--redhat9创建软件仓库

1.插入光盘&#xff0c;挂载镜像 模拟插入光盘: 点击:虚拟机-可移动设备-CD/DVD 设备状态全选&#xff0c;使用ISO影响文件选择当前版本镜像&#xff0c;点击确认。 2.输入: df -h 可以显示&#xff0c;默认/dev/sr0文件为光盘文件&#xff0c;挂载点为/run/media/root/镜像…

【操作系统·考研】文件系统基础

1.概述 文件(File)是以硬盘为载体的存储在计算机上的信息集合&#xff0c;文件可以是文本文档、图片、程序等&#xff0c;基本访问单元可以是字节或记录&#xff0c;可以长期储存在硬盘中&#xff0c;并允许可控制的进程间共享访问&#xff0c;还可以被组成成更复杂的结构。 在…