深度学习中的Dropout正则化:原理、代码实现与实际应用——pytorch框架下如何使用dropout正则化

news2025/1/8 12:49:17

目录

引言

一、导入包

二、dropout网络定义

三、创建模型,定义损失函数和优化器

四、加载数据

五、训练train

六、测试


引言

dropout正则化的原理相对简单但非常有效。它在训练神经网络时,以一定的概率(通常是在0.2到0.5之间)随机地将某些神经元的输出设置为零,即“关闭”这些神经元。这些“关闭”的神经元在整个训练过程中都不参与前向传播和反向传播。

这一过程有点类似于在每次训练迭代中从网络中删除一些神经元,然后在下一次迭代中再将它们添加回去。这种随机的“删除”和“添加”过程迫使网络不依赖于特定的神经元,从而提高了模型的泛化能力。

具体来说,dropout正则化有以下几个效果:

  1. 减少过拟合: 通过随机地关闭一些神经元,dropout可以减少神经网络对训练数据的过度拟合,使其更好地适应未见过的数据。

  2. 增加网络的鲁棒性: 由于每个神经元都有可能在任何时候被关闭,网络被迫学习对于任何输入都要保持稳健,而不是过于依赖某些特定的神经元。

  3. 防止协同适应: 在训练过程中,dropout可以防止神经元之间形成过于强烈的依赖关系,防止它们过分协同适应训练数据。

这一方法的关键在于,通过在训练期间随机地关闭一些神经元,模型不再依赖于任何一个特定的神经元,从而减少了过拟合的风险。在测试阶段,所有的神经元都被保留,但其输出值要按照训练时的概率进行缩放,以保持平衡。

一、导入包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

二、dropout网络定义

# 定义一个带有 dropout 正则化的简单神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 128)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)  # 添加50%的dropout
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.3)  # 添加30%的dropout
        self.fc3 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x

三、创建模型,定义损失函数和优化器

# 创建模型,定义损失函数和优化器
model = Net()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

四、加载数据

# 生成虚拟数据
torch.manual_seed(42)
X_train = torch.rand((1000, 10))
y_train = torch.randint(0, 2, (1000, 1)).float()
X_test = torch.rand((200, 10))
y_test = torch.randint(0, 2, (200, 1)).float()

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

五、训练train

# 训练循环
for epoch in range(10):
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

六、测试

# 测试循环
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_dataloader:
        outputs = model(inputs)
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"测试准确率:{accuracy * 100:.2f}%")

然后我们就可以模仿上述的案例自己修改一些其中的参数进行网络训练和调整了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1245253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于C#实现赫夫曼树

赫夫曼树又称最优二叉树,也就是带权路径最短的树,对于赫夫曼树,我想大家对它是非常的熟悉,也知道它的应用场景,但是有没有自己亲手写过,这个我就不清楚了,不管以前写没写,这一篇我们…

GeoTrust SSL数字安全证书介绍

一、GeoTrust OV证书的介绍 GeoTrust OV证书是由GeoTrust公司提供的SSL证书,它是一种支持OpenSSL的数字证书,具有更高的安全性和可信度。GeoTrust是全球领先的网络安全解决方案提供商,为各类用户提供SSL证书和信任管理服务。GeoTrust OV证书…

Docker+ Jenkins+Maven+git自动化部署

环境:Centos7 JDK1.8 Maven3.3.9 Git 2.40 Docker 20.10.17 准备工作: 安装Docker Centos7默认的yum安装的docker是1.13,版本太低,很多镜像都要Docker版本要求,升级Docker版本。 卸载已安装Docker: yum …

快速在WIN11中本地部署chatGLM3

具体请看智谱仓库github:GitHub - THUDM/ChatGLM3: ChatGLM3 series: Open Bilingual Chat LLMs | 开源双语对话语言模型 或者Huggingface:https://huggingface.co/THUDM/chatglm3-6b 1. 利用Anaconda建立一个虚拟环境: conda create -n chatglm3 pyt…

U盘启动制作工具Rufus

U盘启动制作工具Rufus 下载U盘启动制作工具Rufus,进入Rufus官网:http://rufus.ie/en/,打开之后往后滑动,找到download即可点击下载。 需要插入U盘 首先需要插入U盘,如果U盘有重要文件一定要备份,然后右键…

DB2中实现数据字段的拼接(LISTAGG() 与 xml2clob、xmlagg)

DB2中实现数据字段拼接(LISTAGG 与 xml2clob、xmlagg) 1. 使用函数LISTAGG()1.1 同oracle实现方式1.2 DB2中使用LISTAGG()1.2.1 关于DB2版本1.2.2 数据准备1.2.3 代码实现 2 解决DB2中关于 LISTAGG() 超长问题2.1 使用xmlagg xmlelement2.2 将xml标签去…

【论文阅读笔记】Smil: Multimodal learning with severely missing modality

Ma M, Ren J, Zhao L, et al. Smil: Multimodal learning with severely missing modality[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(3): 2302-2310.[开源] 本文的核心思想是探讨和解决多模态学习中的一个重要问题:在训练和测…

SpringBoot整合Redis,redis连接池和RedisTemplate序列化

SpringBoot整合Redis 1、SpringBoot整合redis1.1 pom.xml1.2 application.yml1.3 配置类RedisConfig,实现RedisTemplate序列化1.4 代码测试 2、SpringBoot整合redis几个疑问?2.1、Redis 连接池讲解2.2、RedisTemplate和StringRedisTemplate 3、RedisTemp…

DALSA.SaperaLT.SapClassBasic无法加载,试图加载格式不正确的程序,c#

情景:用c#wpf写DALSA线扫相机的项目,生成时不报错,运行到DALSA相关的代码就报错找不到dll(DALSA的技术支持没给到任何支持 ) 一.根据框架选择dll 如果是.net framework框架(比如说.net480)&am…

【LeetCode刷题-链表】--61.旋转链表

61.旋转链表 方法: 记给定的链表的长度为n,注意当向右移动的次数k>n时,仅需要向右移动k mod n次即可,因为每n次移动都会让链表变为原状 将给定的链表连接成环,然后将指定位置断开 /*** Definition for singly-linked list.*…

【OJ比赛日历】快周末了,不来一场比赛吗? #11.25-12.01 #17场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号会推送最新的比赛消息,欢迎关注! 以下信息仅供参考,以比赛官网为准 目录 2023-11-25(周六) #9场比赛2023-11-26…

Linux(Centos)上使用crontab实现定时任务(定时执行脚本)

场景 Windows中通过bat定时执行命令和mysqldump实现数据库备份: Windows中通过bat定时执行命令和mysqldump实现数据库备份_mysqldump bat-CSDN博客 上面讲windows中使用bat实现定时任务的方式,如果是在linux上可以通过crontab实现。 cron是服务名称。…

案例018:基于微信小程序的实习记录系统

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…

Java核心知识点整理大全10-笔记

往期快速传送门: Java核心知识点整理大全-笔记_希斯奎的博客-CSDN博客文章浏览阅读9w次,点赞7次,收藏7次。Java核心知识点整理大全https://blog.csdn.net/lzy302810/article/details/132202699?spm1001.2014.3001.5501 Java核心知识点整理…

服务器流量包扣减规则

服务器买的流量包,一般指的是上行带宽,下行通常是不限的 上行和下行是针对服务器而言的 客户端上传文件给服务器,对服务器而言它是在下载,所以对服务器而言他是用的下行带宽(下行流量) 客户端从服务器下载文件,对服务器而言它是在上传,所以对服务器而言他是用的上行带宽(上行…

HTB Napper WriteUp

Napper 2023年11月12日 14:58:35User Nmap ➜ Napper nmap -sCV -A -p- 10.10.11.240 --min-rate 10000 Starting Nmap 7.80 ( https://nmap.org ) at 2023-11-12 13:58 CST Nmap scan report for app.napper.htb (10.10.11.240) Host is up (0.15s latency). Not shown: …

Linux文件查看命令

1.cat加上文件名 (因为所有文件内容都会打印到屏幕上,所以内容少时使用这个,总不能用cat来定义一本小说) 3.往文件中写入数据——cat加上>(重定向符)加上文件名,写完之后,按键 cat原本是把…

设计模式——行为型模式(一)

行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。 行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行…

Java基层卫生健康云综合管理(云his)系统源码

云HIS(Cloud-Based Healthcare Information System)是基于云计算的医院健康卫生信息系统。它运用云计算、大数据、物联网等新兴信息技术,按照现代医疗卫生管理要求,在一定区域范围内以数字化形式提供医疗卫生行业数据收集、存储、…

对tensor的处理函数:expand_as(尺寸扩展),nonzero(获取非零元素索引)

Tensor.expand_as(other) 扩展tensor到与other相同的尺寸 torch.nonzero(input, as_tupleFalse) 或 Tensor.nonzero() 返回input中非零元素的索引 indices 1)as_tuple False:返回的结果是tensor,z \times n,z为input中非零元素个…