什么是正则化?Regularization: The Stabilizer of Machine Learning Models(中英双语)

news2024/12/18 23:08:09

正则化:机器学习模型的稳定器


1. 什么是正则化?

正则化(Regularization)是一种在机器学习模型训练中,通过约束模型复杂性以防止过拟合的技术
它的核心目标是让模型不仅在训练集上表现良好,还能在测试集上具有良好的泛化能力。


2. 为什么正则化起作用?

2.1 过拟合的本质

过拟合通常发生在模型参数过多、数据量不足或数据噪声较大时,模型学到了数据中的噪声和不相关的模式,从而导致泛化能力下降。

2.2 正则化的作用原理

正则化通过引入额外的约束条件来抑制模型的复杂性,限制其自由度,使得模型更倾向于学习数据的总体模式而非局部噪声。

数学原理
正则化通过在损失函数中添加正则项,改变了优化目标,从而约束模型的参数空间。以常见的线性回归为例:

  • 原始损失函数(最小化误差):
    L = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \mathcal{L} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 L=n1i=1n(yiy^i)2
  • 加入正则化后的损失函数:
    L reg = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ R ( θ ) \mathcal{L}_{\text{reg}} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda R(\theta) Lreg=n1i=1n(yiy^i)2+λR(θ)

其中:

  • ( R ( θ ) R(\theta) R(θ) ) 是正则项,用于约束模型参数 ( θ \theta θ )。
  • ( λ \lambda λ ) 是正则化强度的超参数,用于权衡数据拟合与正则化之间的关系。

3. 常见的正则化方法

3.1 参数正则化:L1 和 L2 正则化
  • L1 正则化(Lasso Regression)
    在损失函数中加入 ( L 1 L1 L1 ) 范数的约束:
    R ( θ ) = ∥ θ ∥ 1 = ∑ j = 1 p ∣ θ j ∣ R(\theta) = \|\theta\|_1 = \sum_{j=1}^p |\theta_j| R(θ)=θ1=j=1pθj

    • 优点:促使部分参数变为零,从而实现特征选择。
    • 缺点:在高维数据中可能会丢失部分信息。
  • L2 正则化(Ridge Regression)
    在损失函数中加入 ( L2 ) 范数的约束:
    R ( θ ) = ∥ θ ∥ 2 2 = ∑ j = 1 p θ j 2 R(\theta) = \|\theta\|_2^2 = \sum_{j=1}^p \theta_j^2 R(θ)=θ22=j=1pθj2

    • 优点:通过惩罚较大的参数值,抑制模型复杂性。
    • 缺点:不会稀疏参数,所有特征都会保留。

代码示例(以线性回归为例):

import numpy as np
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# 模拟数据
np.random.seed(42)
X = np.random.rand(100, 5)
y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.randn(100)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# L2 正则化(Ridge)
ridge = Ridge(alpha=1.0)  # alpha 控制正则化强度
ridge.fit(X_train, y_train)
y_pred_ridge = ridge.predict(X_test)

# L1 正则化(Lasso)
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
y_pred_lasso = lasso.predict(X_test)

print("Ridge MSE:", mean_squared_error(y_test, y_pred_ridge))
print("Lasso MSE:", mean_squared_error(y_test, y_pred_lasso))

3.2 数据增强(Data Augmentation)
  • 数据增强是通过对训练数据进行扩充(如图像翻转、裁剪、旋转等),使模型看到更多变种,从而提升泛化能力。
  • 常用于计算机视觉和自然语言处理领域。

代码示例(以 PyTorch 图像增强为例):

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# 数据增强
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

# 加载数据集
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 打印增强后的图像形状
for images, labels in train_loader:
    print(images.shape)  # (64, 3, 32, 32)
    break

3.3 Dropout
  • Dropout 是一种在训练过程中随机“丢弃”一部分神经元的正则化技术,用于防止神经网络过拟合。
  • 训练时,随机将一部分神经元的输出置为零;推理时,使用所有神经元,但缩放其输出。

数学原理
假设 Dropout 比例为 ( p p p ),每个神经元有 ( 1 − p 1-p 1p ) 的概率被激活:
输出 = 激活值 ⋅ 掩码 / ( 1 − p ) \text{输出} = \text{激活值} \cdot \text{掩码} / (1-p) 输出=激活值掩码/(1p)

代码示例

import torch
import torch.nn as nn

# 定义一个简单的网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.dropout = nn.Dropout(p=0.5)  # Dropout 概率为 0.5
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# 使用 Dropout 的网络
model = SimpleNN()
print(model)

3.4 大模型中的正则化方法

在深度学习领域(尤其是 2022-2023 年的大模型训练),一些新的正则化方法逐渐被广泛应用:

  1. LayerNorm 和 WeightNorm

    • LayerNorm 对每一层进行归一化,减少梯度消失或爆炸问题。
    • WeightNorm 通过分离权重的幅度和方向,提升模型收敛速度。
  2. Label Smoothing

    • 通过在训练目标上引入少量噪声,避免模型过度自信。
      y ~ = ( 1 − ϵ ) ⋅ y + ϵ / K \tilde{y} = (1 - \epsilon) \cdot y + \epsilon / K y~=(1ϵ)y+ϵ/K
  3. 梯度裁剪(Gradient Clipping)

    • 限制梯度更新的幅度,避免梯度爆炸。
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  4. 正则化优化器

    • AdamW 是一种带权重衰减的优化器,直接在更新权重时加入 L2 正则化效果。

4. 正则化在大模型中的实际应用

以 GPT-3 或 BERT 等大语言模型的训练为例,正则化方法的组合应用非常重要:

  • 使用 LayerNormDropout 作为网络层内的正则化手段。
  • 在优化器中应用 AdamW,并设置适当的权重衰减参数。
  • 在大数据集上进行分布式训练,同时引入数据增强策略。

5. 总结

正则化技术是机器学习和深度学习中不可或缺的一部分,帮助模型在复杂场景下提升泛化能力并防止过拟合。
不同场景适合的正则化方法如下:

场景常用正则化方法
传统机器学习(线性模型)L1 正则化、L2 正则化
神经网络训练Dropout、数据增强
大模型训练(2022-2023)LayerNorm、AdamW、梯度裁剪、Label Smoothing

正则化方法的选择依赖于具体任务和模型的需求,但其核心思想始终是限制模型的复杂性,提升模型的稳定性和泛化能力。

Regularization: The Stabilizer of Machine Learning Models


1. What is Regularization?

Regularization is a set of techniques used in machine learning to constrain model complexity and prevent overfitting.
The primary goal of regularization is to ensure that the model performs well not only on the training data but also generalizes effectively to unseen test data.


2. Why Does Regularization Work?

2.1 The Nature of Overfitting

Overfitting happens when a model learns noise and irrelevant patterns in the training data, leading to poor generalization on new data. This is more common in cases with:

  • Insufficient training data
  • High model complexity
  • Noisy datasets
2.2 How Regularization Works

Regularization works by imposing constraints on the model’s complexity. This discourages it from fitting noise and forces it to focus on learning the underlying patterns in the data.

Mathematical Insight:
By adding a regularization term to the loss function, we effectively change the optimization objective, which restricts the parameter space.

For example, in linear regression:

  • Original loss function:
    L = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \mathcal{L} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 L=n1i=1n(yiy^i)2
  • Regularized loss function:
    L reg = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ R ( θ ) \mathcal{L}_{\text{reg}} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda R(\theta) Lreg=n1i=1n(yiy^i)2+λR(θ)

Where:

  • ( R ( θ ) R(\theta) R(θ) ) is the regularization term that penalizes complex models.
  • ( λ \lambda λ ) controls the trade-off between fitting the data and regularization strength.

3. Common Regularization Techniques

3.1 Parameter Regularization: L1 and L2 Regularization
  • L1 Regularization (Lasso)
    Adds the ( L 1 L1 L1 )-norm of the parameters to the loss function:
    R ( θ ) = ∥ θ ∥ 1 = ∑ j = 1 p ∣ θ j ∣ R(\theta) = \|\theta\|_1 = \sum_{j=1}^p |\theta_j| R(θ)=θ1=j=1pθj

    • Advantages: Encourages sparsity, making some parameters zero. Useful for feature selection.
    • Disadvantages: May lose some information in high-dimensional data.
  • L2 Regularization (Ridge)
    Adds the ( L 2 L2 L2 )-norm of the parameters to the loss function:
    R ( θ ) = ∥ θ ∥ 2 2 = ∑ j = 1 p θ j 2 R(\theta) = \|\theta\|_2^2 = \sum_{j=1}^p \theta_j^2 R(θ)=θ22=j=1pθj2

    • Advantages: Shrinks large parameter values, reducing model complexity.
    • Disadvantages: Does not produce sparse parameters; retains all features.

Code Example (Linear Regression with L1 and L2):

import numpy as np
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 5)
y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.randn(100)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Ridge (L2) Regularization
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)
y_pred_ridge = ridge.predict(X_test)

# Lasso (L1) Regularization
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
y_pred_lasso = lasso.predict(X_test)

print("Ridge MSE:", mean_squared_error(y_test, y_pred_ridge))
print("Lasso MSE:", mean_squared_error(y_test, y_pred_lasso))

3.2 Data Augmentation

Data augmentation expands the training dataset by applying transformations (e.g., flips, rotations, cropping) to existing data, increasing model robustness and improving generalization.

Example (Image Augmentation in PyTorch):

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Define data augmentation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

# Load dataset with augmentation
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Print augmented image shape
for images, labels in train_loader:
    print(images.shape)  # Example: (64, 3, 32, 32)
    break

3.3 Dropout

Dropout randomly deactivates a subset of neurons during training, reducing reliance on specific neurons and preventing co-adaptation.

Mathematical Insight:
For a dropout rate ( p p p ), each neuron’s output is retained with probability ( 1 − p 1-p 1p ). During inference, the full network is used but scaled by ( 1 − p 1-p 1p ).

Code Example:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.dropout = nn.Dropout(p=0.5)  # 50% dropout
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = SimpleNN()
print(model)

3.4 Advanced Regularization Techniques for Large Models

With the advent of large-scale models (2022-2023), new regularization techniques have been widely adopted:

  1. LayerNorm and WeightNorm

    • LayerNorm normalizes activations across features within a layer.
    • WeightNorm separates weight vectors into magnitude and direction, improving optimization stability.
  2. Label Smoothing
    Prevents overconfidence in predictions by softening the target distribution:
    y ~ = ( 1 − ϵ ) ⋅ y + ϵ / K \tilde{y} = (1 - \epsilon) \cdot y + \epsilon / K y~=(1ϵ)y+ϵ/K

  3. Gradient Clipping
    Limits the magnitude of gradients to prevent exploding gradients:

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  4. AdamW Optimizer
    Combines the Adam optimizer with weight decay for improved regularization.


4. Regularization in Large Model Training

For models like GPT-3 and BERT, regularization involves combining multiple techniques:

  • LayerNorm and Dropout to stabilize training and reduce overfitting.
  • AdamW with appropriate weight decay settings.
  • Label Smoothing for classification tasks to prevent overconfidence.
  • Gradient Clipping to handle gradient explosion in deep networks.

5. Conclusion

Regularization is crucial for building robust machine learning models. The right choice of technique depends on the specific task and model requirements. Below is a summary of common regularization techniques:

ScenarioRegularization Methods
Traditional ML (linear models)L1, L2 regularization
Neural Network TrainingDropout, Data Augmentation
Large Model TrainingLayerNorm, AdamW, Label Smoothing

By constraining model complexity, regularization ensures models are stable, generalizable, and less prone to overfitting.

后记

2024年12月14日15点55分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2261832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计规规范:【App 配色】

文章目录 引言I App 配色组成色彩象征 & 联想II 知识扩展设计流程图UI设计交互设计UI交互设计引言 设计规范,保持设计一致性,提高设计效率。宏观上对内统一,管理与合作变得容易。 按类型管理颜色、文本样式、图标、组件(symbol)。 蓝湖设计规范云 https://lanhuapp.co…

[maven]使用spring

为了更好理解springboot,我们先通过学习spring了解其底层。 这里讲一下简单的maven使用spring框架入门使用。因为这一块的东西很多都需要联合起来后才好去细讲,本篇通过spring-context大致地介绍相关内容。 注意:spring只是一个框架&#xff…

Unity性能优化---使用SpriteAtlas创建图集进行批次优化

在日常游戏开发中,UI是不可缺少的模块,而在UI中又使用着大量的图片,特别是2D游戏还有很多精灵图片存在,如果不加以处理,会导致很高的Batches,影响性能。 比如如下的例子: Batches是9&#xff0…

transformer学习笔记-位置编码

在transformer学习笔记-自注意力机制(1)学习原理的时候,我们提到: 将句子从“苹果梨”,改成“梨苹果”,最终的到的新苹果和新梨,竟然是一样的,因为苹果和梨两个向量调换顺序后,对应计…

【Unity3D】实现UGUI高亮引导点击

Unity版本2019.4.0f1 Personal <DX11> using UnityEngine; using UnityEngine.UI;public class GuideMask : MonoBehaviour, ICanvasRaycastFilter {public Canvas canvas;public Transform guideTargetTrans;public Image image;private Vector3 guideTargetWorldPos;pr…

Springboot3.x配置类(Configuration)和单元测试

配置类在Spring Boot框架中扮演着关键角色&#xff0c;它使开发者能够利用Java代码定义Bean、设定属性及调整其他Spring相关设置&#xff0c;取代了早期版本中依赖的XML配置文件。 集中化管理&#xff1a;借助Configuration注解&#xff0c;Spring Boot让用户能在一个或几个配…

SpringBoot增删改查导入导出操作【模板】

SpringBoot增删改查导入导出操作【模板】 文章目录 SpringBoot增删改查导入导出操作【模板】前期数据库操作IDEA上进行操作1. 创建 Spring Boot 项目2. 项目结构3. pom.xml文件4. 配置数据库连接并进行测试5. 创建实体类6. 创建 MyBatis Mapper7. 创建服务层8. 创建控制器9. 启…

mfc140.dll是什么东西?mfc140.dll缺失的几种具体解决方法

mfc140.dll是Microsoft Foundation Classes&#xff08;MFC&#xff09;库中的一个动态链接库&#xff08;DLL&#xff09;文件&#xff0c;它是微软基础类库的一部分&#xff0c;为Windows应用程序的开发提供了丰富的类库和接口。MFC库旨在简化Windows应用程序的开发过程&…

探索Starship:一款用Rust打造的高性能终端

在终端的世界里&#xff0c;效率和美观往往并行不悖。今天&#xff0c;我们要介绍的是一款名为Starship的终端工具&#xff0c;它以其轻量级、高颜值和强大的自定义功能&#xff0c;赢得了众多开发者的青睐。 安装 任选一种方式进行安装 Windows &#x1fa9f; # scoop scoo…

2024年NSSCTF秋季招新赛-WEB

The Beginning F12看源码&#xff0c;有flag http标头 黑吗喽 题目说要在发售时的0点0分&#xff0c;所以添加标头data Date: Tue, 20 Aug 2024 00:00:00 GMT然后改浏览器头 User-Agent: BlackMonkey曲奇就是Cookie cookieBlackMonkey这个一般就是Referer Referer:wukon…

TQ15EG开发板教程:使用SSH登录petalinux

本例程在上一章“创建运行petalinux2019.1”基础上进行&#xff0c;本例程将实现使用SSH登录petalinux。 将上一章生成的BOOT.BIN与imag.ub文件放入到SD卡中启动。给开发板插入电源与串口&#xff0c;注意串口插入后会识别出两个串口号&#xff0c;都需要打开&#xff0c;查看串…

windos系统安装-mysql 5.7 zip压缩包教程

一, 安装包下载 在mysql官网上下载mysql5.7版本的压缩包 官方网址: https://dev.mysql.com/downloads/mysql/5.7.html#downloads选择历史版本 选择系统和数据库版本下载 下载完成后解压到安装的目录 二, 新增数据目录,配置文件, 配置环境变量 新建data文件夹用于存放数据库…

js 获取屏幕高度和宽度的几种方式

1、document.documentElement.clientHeight 屏幕可视区域高度&#xff0c;文档的根元素&#xff08;通常是 <html> 元素&#xff09;的高度&#xff0c;但会受到CSS样式的影响。 实际应用&#xff1a;对于H5的移动端&#xff0c;希望video元素在全屏状态下占满整个手机屏…

Tree-of-Counterfactual Prompting for Zero-Shot Stance Detection

论文地址&#xff1a;Tree-of-Counterfactual Prompting for Zero-Shot Stance Detection - ACL Anthologyhttps://aclanthology.org/2024.acl-long.49/ 1. 概述 立场检测被定义为对文本中立场态度的自动推断。根据 Biber 和 Finegan (1988) 的定义&#xff0c;立场包含两个主…

css基础-认识css

什么是css css是一个样式表&#xff0c;是对html的一种装饰&#xff0c;它决定了浏览器如何显示html元素&#xff0c;例如&#xff1a; h1 {color:blue; //文字颜色是蓝色font-size:12px; //字体大小为12像素 }上段css代码就是对HTML 中 <h1>标签的修饰&#xff1b;所以…

【Unity功能集】TextureShop纹理工坊(二)图层(下)

项目源码&#xff1a;后期发布 索引 图层渲染绘画区域图层Shader 编辑器编辑模式新建图层设置当前图层上、下移动图层删除图层图层快照 图层 在PS中&#xff0c;图层的概念贯穿始终&#xff08;了解PS图层&#xff09;&#xff0c;他可以称作PS最基础也是最强大的特性之一。 …

云计算HCIP-OpenStack02

书接上回&#xff1a; 云计算HCIP-OpenStack01-CSDN博客 7.OpenStack核心服务 7.1Horizon&#xff1a;界面管理服务 Horizon提供了OpenStack中基于web界面的管理控制页面&#xff0c;用户或者是管理员都需要通过该服务进行OpenStack的访问和控制 界面管理服务需要依赖于keyston…

Word2Vec:将词汇转化为向量的技术

文章目录 Word2Vec来龙去脉分层Softmax负采样 Word2Vec 下面的文章纯属笔记&#xff0c;看完后不会有任何收获&#xff0c;如果想理解这两种优化技术&#xff0c;给大家推荐一篇博客&#xff0c;讲的很好&#xff1a; 详解-----分层Softmax与负采样 来龙去脉 word2vec,即将词…

电商商品详情API接口(item get)数据分析上货

电商商品详情API接口&#xff08;item get&#xff09;在数据分析与商品上货方面发挥着重要作用。以下是对这两个方面的详细探讨&#xff1a; 一、数据分析 数据源获取&#xff1a; 商品详情API接口提供了丰富的数据源&#xff0c;包括商品的标题、价格、库存、描述、图片、用…

如何将你的 Ruby 应用程序从 OpenSearch 迁移到 Elasticsearch

作者&#xff1a;来自 Elastic Fernando Briano 将 Ruby 代码库从 OpenSearch 客户端迁移到 Elasticsearch 客户端的指南。 OpenSearch Ruby 客户端是从 7.x 版 Elasticsearch Ruby 客户端分叉而来的&#xff0c;因此代码库相对相似。这意味着当将 Ruby 代码库从 OpenSearch 迁…