【chapter29】【PyTorch】[Regularization】

news2025/1/13 3:05:37

前言:

     前面讲了Overfitted,这里重点讲解一下如何防止

Overfitting ,以及其中的方案之一 Regularization

 模型的参数量,模型的表达能力远超模型本身复杂度.

与之对应的是奥卡姆剃刀原理:

      如何用最简单的方法得到最好的效果

       找到关键的部分,简单灵活地去处理,你会发现,成功并不那么复杂。

目录:

    1 more data

    2: constraint model complexity

    3 Dropout

    4 data  argumentation

   5   Early Stopping

   6  Regularization


一:  More data

     增加Train Data 数据集大小

 有的时候,我们项目中没有办法获得那么大的数据集,

这个时候可以通过Gain 或者编码器   合成数据集,达到增加数据集容量

二   constraint model complexity(限制模型复杂度)

     2.1  shallow: 

     单隐藏层神经网络就是典型的浅层(shallow)神经网络,即只包含一层隐含层


 

 2.2 regularization

        正规化 分为L1 正规化,和L2 正规化

     在机器学习里面是一种最常用的方案


三 Dropout  

            dropout(随机失活):dropout是通过遍历神经网络每一层的节点,然后通过对该层的神经网络设置一个keep_prob(节点保留概率),即该层的节点有keep_prob的概率被保留,keep_prob的取值范围在0到1之间。通过设置神经网络该层节点的保留概率,使得神经网络不会去偏向于某一个节点(因为该节点有可能被删除),从而使得每一个节点的权重不会过大,有点类似于L2正则化,来减轻神经网络的过拟合。

1    首先随机(临时)删掉网络中一半的隐藏神经元(以dropout rate为0.5为例),输入输出神经元保持不变

2  然后把输入x 通过修改后的网络前向传播,然后把得到的损失结果通过修改的网络反向传播。一  小批(这里的批次batch_size由自己设定)训练样本执行完这个过程后,在没有被删除的神经元(C_r)上按照随机梯度下降法更新对应的参数(w,b)


重复以下过程:
sub-1、恢复被删掉的神经元(C_d)(此时被删除的神经元保持原样,而没有被删除的神经元已经有所更新),因此每一个mini- batch都在训练不同的网络。


sub-2、从隐藏层神经元中随机选择一个一半大小的子集临时删除掉(备份被删除神经元的参数)。
sub-3、对一小批训练样本,先前向传播然后反向传播损失并根据随机梯度下降法更新参数(w,b) (没有被删除的那一部分参数得到更新C_r,删除的神经元参数保持被删除前的结果)。
 


四  data  argumentation

     数据增强,这在图像处理里面常用

随机旋转一般情况下是对输入图像随机旋转[0,360)
随机裁剪是对输入图像随机切割掉一部分
色彩抖动指的是在颜色空间如RGB中,每个通道随机抖动一定的程度。
是指在图像中随机加入少量的噪声。该方法对防止过拟合比较有效
水平翻转
竖直翻转


五 Early Stopping

     使用Validation Data 做一个提前终止


六  Regularization

     J(\theta)=-\frac{1}{m}\sum_i y_iln \hat{y_i}+(1-y_i)ln (1-\hat{y_i})+\lambda \sum_j |\theta_j|

  

正规化有两种,一种是L1 正规化(有降维的效果)

另一种是L2正规化

 对于L2 正规化有默认的参数可以直接配置

# -*- coding: utf-8 -*-
"""
Created on Wed Apr 26 16:56:05 2023

@author: chengxf2
"""

import torch
from torch import optim
from torch import nn


# 先定义一个三层感知机,激活函数使用Relu(小于0的,都转换为0)
class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim1, hid_dim2, out_dim):
        super(MLP, self).__init__()
        #使用Sequential快速搭建三层感知机
        self.layer = nn.Sequential(
            # 第一层
            nn.Linear(in_dim, hid_dim1),
            nn.ReLU(),
            nn.Linear(hid_dim1, hid_dim2),
            nn.ReLU(),
            nn.Linear(hid_dim2, out_dim),
            nn.ReLU()
            )
    def forward(self, x):
        y = self.layer(x)
        return y
 

def train():
    print("\n step1 init model")
    learning_rate =1e-3
    net =MLP(28*28, 300, 200, 10)
    optimizer = optim.SGD(net.parameters(),lr =learning_rate, weight_decay=0.01) #L2正规化
    criteon = nn.CrossEntropyLoss()

    print("\n step2 forward")
    data = torch.randn(10, 28*28)
    output = net(data)
    label = torch.Tensor([1, 0, 4, 7, 9, 3, 4, 5, 3, 2]).long()
   
    
    print("\n step3 backward")
    loss = criteon(output, label)

    # 清空梯度,在每次优化前都要进行此操作
    optimizer.zero_grad()
    # 损失的反向传播
    loss.backward()
    # 利用优化器进行梯度更新
    optimizer.step()
    
if __name__ == "__main__":
     train()

针对L1 正规化,PyTorch 需要自己实现,方案如下

# -*- coding: utf-8 -*-
"""
Created on Thu Apr 27 15:23:56 2023

@author: chengxf2
"""

import torch
from torch import optim
from torch import nn

# 先定义一个三层感知机,激活函数使用Relu(小于0的,都转换为0)
class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim1, hid_dim2, out_dim):
        super(MLP, self).__init__()
        #使用Sequential快速搭建三层感知机
        self.layer = nn.Sequential(
            # 第一层
            nn.Linear(in_dim, hid_dim1),
            nn.ReLU(),
            nn.Linear(hid_dim1, hid_dim2),
            nn.ReLU(),
            nn.Linear(hid_dim2, out_dim),
            nn.ReLU()
            )
    def forward(self, x):
        y = self.layer(x)
        return y
net =MLP(100, 50, 25, 10)
optimizer = optim.SGD(net.parameters(),lr=1e-3)
regularization_loss = 0.0
criteon = nn.CrossEntropyLoss()

for param in net.parameters():
    L1= torch.abs(param)
    regularization_loss +=L1
    
    
data = torch.randn(10, 28*28)    
logits = net(data)
target = torch.Tensor([1, 0, 4, 7, 9, 3, 4, 5, 3, 2]).long()
classify_loss = criteon(logits, target)

loss = classify_loss+0.01*regularization_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

参考

Dropout的深入理解(基础介绍、模型描述、原理深入、代码实现以及变种)_dropout模型_ㄣ知冷煖★的博客-CSDN博客
数据增强(Data Argumentation)_左小田^O^的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/468303.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SWAT水文模型】SWAT水文模型建立及应用第三期:土壤库建立(待更新)

SWAT水文模型建立及应用:土壤库建立 1 简介2 土壤数据下载2.1 数据下载方式2.1.1 世界土壤数据库HWSD数据2.1.2 中国土壤数据库 2.2 数据下载 3 土壤数据的准备3.1 SWAT土壤数据库参数3.2 土壤质地转化3.3 土壤参数的提取3.4 其他变量的提取3.5 土壤类型分布图的处理…

回顾 | Pre VS Code Day - 用 GitHub Codespaces 构建 OpenAI 应用实战

编辑:Alan Wang 排版:Rani Sun 微软 Reactor 为帮助广开发者,技术爱好者,更好的学习 .NET Core, C#, Python,数据科学,机器学习,AI,区块链, IoT 等技术,将每周三到周六&a…

85.qt qml-炫酷烟花粒子特效(支持多种爆炸模式、爆炸阴影、背景场景)

效果如下所示: 截图如下所示: 实现内容如下所示: 1.实现多个爆炸效果2.爆炸的时候增加光度阴影效果3.由于场景有湖面,所以还需要增加一个倒影粒子组首先我们来学习下,该示例中所需要常用的类型点 1.如何更改粒子生命周期时的颜色变换动画 方法有两种。 1.1通过colorTable和si…

详谈Android进程间的大数据通信机制:LocalSocket

前言 说起Android进行间通信,大家第一时间会想到AIDL,但是由于Binder机制的限制,AIDL无法传输超大数据。 比如我们在之前文章《WebRtc中是如何处理视频数据的?》提到的我们可以得到WebRtc的视频数据,这时候我们如果有…

C++类和对象(4)

C类和对象 1.拷贝构造函数1.1 概念1.2. 特征1.2.1. 拷贝构造函数构造函数的一种重载形式;1.2.2. 拷贝构造函数的参数只能有一个,是对本类类型对象的引用,不能是传值调用,编译会直接报错,或者是直接进入死循环。1.2.3. …

wvp开发环境搭建

代码下载地址 代码下载地址 https://gitee.com/pan648540858/wvp-GB28181-pro.git 开发工具 采用jetbrain idea 利用开发工具下载代码 文件-新建-来自版本控制的项目 url是上面的代码下载链接,点击克隆即可 下图是已经克隆并打开的代码 安装依赖环境 安装redi…

基于html+css的图展示44

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

Linux套接字编程

在上一篇博客中我们对网络中一些基本概念进行了简单阐述,这一篇博客我们来对套接字编程的内容进行初步了解。 目录 1.引入 2.UDP协议 2.1通信两端流程 2.1.1服务端流程 2.1.2客户端流程 2.2套接字相关操作接口 2.2.1创建套接字 2.2.2为套接字绑定地址信息 …

SSL证书周期变为90天? 锐成让您轻松应对行业新规

3月3日,谷歌在其“Move Forward, Together”栏目中,称已向CA/B论坛发起了投票提案,建议将公共TLS(也称为SSL)证书的最长有效期从398天减少到90天。值得注意的是,即便CA/B论坛没有通过这一提议,谷…

【C语言】函数讲解(下)

【C语言】函数讲解(下) 1.函数的声明和定义1.1函数声明1.2函数定义 2.函数的嵌套调用和链式访问2.1嵌套调用2.2链式访问 3.函数递归3.1什么是递归3.2递归的两个必要条件3.2.1练习13.2.2练习2 3.3递归与迭代3.3.1练习13.3.2练习2 所属专栏:C语…

Android Jetpack—LiveData

1.LiveData LiveData是Android Jetpack包提供的一种可观察的数据存储器类,它可以通过添加观察者被其他组件观察其变更。不同于普通的观察者,LiveData最重要的特征是它具有生命周期感知能力,它遵循其他应用组件(如 Activity、Frag…

软件测试—进阶篇

软件测试—进阶篇 🔎根据测试对象划分界面测试可靠性测试容错性测试文档测试兼容性测试易用性测试安装卸载测试安全性测试性能测试内存泄漏测试 🔎根据是否查看代码划分黑盒测试白盒测试灰盒测试 🔎根据开发阶段划分单元测试集成测试系统测试…

mulesoft MCIA 破釜沉舟备考 2023.04.27.25 (易错题)

@[TOC](mulesoft MCIA 破釜沉舟备考 2023.04.27.25 (易错题)) 1. According to MuleSoft, which deployment characteristic applies to a microservices application architecture? A. Services exist as independent deployment artifacts and can be scaled independently…

ABeam Insight | 智能制造系列(6):虚拟/增强现实(VR/AR)×智能制造

虚拟现实(VR)和增强现实(AR)的概念早在20世纪60年代就被提出,但由于当时的技术水平无法满足相关应用的需求,这些概念并没有引起广泛关注。直到近年来随着计算机技术的飞速发展,虚拟现实和增强现…

python+nodejs+php+springboot+vue高校教室自习室预约管理系统

建立的自习室预约管理系统用户使用浏览器就可以对其进行访问,管理员在操作上面能够方便管理,因此用户和管理员能够方便对这个系统进行操作。论文全面介绍系统数据库,功能设计和业务流程设计。数据库能够存储自习室预约管理系统需要的数据。 …

Leanback(1)-播放控制栏下添加新的行

我们要在播放控制栏下面加入下面一行。 这个就是标准的row。 leanback的原理 Android Leanback结构源码简析 - 简书 我们知道Row用来提供数据,row可以通过一个ObjectAdapter来管理和提供数据 我们知道presenter是一个负责将数据绑定到视图上的对象,它可以…

基于STM32的智能语音垃圾桶设计

一. 系统设计及框图: 本设计整体功能如下: 1. 超声波感应到有人靠近时语音提示“垃圾放置请分类”。 2. 检测垃圾筒时是否满,当满时语音提示“垃圾桶已满”。 3. 光传感器检测,指示灯指示。 4. 语音识别不同的垃圾类型。 二.…

前端程序员的职业发展规划与路线——ChatGPT的回答

文章目录 一、前端程序员的职业规划是?回答1: 作为一个前端开发程序员,您的职业发展路线可能如下:回答2:作为前端开发程序员,您的职业发展路线可能如下:回答3: 你的职业发展路线可能…

ASEMI代理ADI亚德诺ADM3051CRZ-REEL7车规级芯片

编辑-Z ADM3051CRZ-REEL7芯片参数: 型号:ADM3051CRZ-REEL7 显性状态:78 mA 隐性状态:10 mA 待命状态:275μA CANH输出电压:4.5V CANL输出电压:2V 差动输出电压:3V 输入电压…

【AI生产力工具】Upscale.media:用AI技术提升照片质量,让你的作品更出色

文章目录 简介一、Upscale.media是什么?二、如何使用Upscale.media?三、总结 简介 在如今的数字时代,图片已经成为我们日常生活中不可或缺的一部分,从社交媒体到电子商务网站,从广告宣传到个人生活,都需要…