人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解

news2024/11/16 6:00:20

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解。在机器学习和深度学习领域,模型的训练目标是找到一组参数,使得模型能够从训练数据中学习到有用的模式,并对未知数据做出准确预测。这一过程涉及到解决两种主要的拟合问题:欠拟合(Underfitting)和过拟合(Overfitting)。

文章目录

  • 一、拟合问题概述
    • 欠拟合现象
    • 过拟合现象
    • 解决策略
  • 二、正则化方法
    • 1. L1正则化
    • 2. L2正则化
  • 三、正则化参数的更新
  • 四、Dropout
  • 五、代码实现

一、拟合问题概述

在机器学习领域,拟合问题是指通过训练数据找到最佳模型参数,使得模型在未知数据上的表现尽可能好。拟合问题主要包括欠拟合和过拟合两种现象。

欠拟合现象

定义:欠拟合指的是机器学习模型在训练集上的表现不佳,无法充分学习到数据的内在规律,导致模型的预测能力低下。这就好比一个学生在考试中,由于知识掌握不牢固,对已知题目的解答都做不好,更不用说应对新题目了。
原因分析:
模型复杂度低:如果模型太简单,如用线性模型去拟合非线性的数据分布,那么模型就无法捕捉到数据中的复杂模式,就像用直尺去测量曲线长度一样,永远无法得到准确的结果。
训练数据不足:模型需要足够的数据来学习和概括数据的特性。如果数据量太少,模型可能没有机会接触到数据的全貌,就像从一本书中只读了几页就想理解整本书的内容一样困难。
特征选择不当:如果使用的特征与目标预测无关或相关性弱,模型就难以从中学习到有效的信息,相当于在解决问题时选择了错误的工具。

过拟合现象

定义:过拟合是指模型在训练数据上表现得过于出色,以至于对训练数据中的噪声或偶然性细节也进行了学习,这导致模型在面对未见过的数据时,泛化能力下降。这就像一个学生过分依赖于记忆特定的例题,而没有真正理解背后的原理,因此在遇到稍微变化的问题时就束手无策。
原因分析:
模型复杂度过高:如果模型过于复杂,如高阶多项式回归,它可能会过度适应训练数据中的每一个细节,包括噪声和异常值,而不是学习数据的普遍规律。
训练数据包含噪声:现实世界的数据往往带有噪声,如果模型试图学习这些噪声,就会导致过拟合。这类似于试图从嘈杂的环境中听清对话,噪声会干扰对真实信息的理解。
训练数据量不足:即使模型复杂度适中,但如果训练数据量不够,模型仍然可能过拟合。这是因为数据量不足时,模型可能会把偶然出现的模式误认为是普遍规律。

解决策略

增加模型复杂度:对于欠拟合,可以通过增加模型复杂度来提升模型的学习能力,如使用更高阶的多项式或更复杂的神经网络结构。
增加训练数据量:无论是欠拟合还是过拟合,增加训练数据量都能帮助模型更好地学习数据的分布,提高泛化能力。
特征工程:优化特征选择,确保模型能够基于有意义的特征进行学习。
正则化:使用L1或L2正则化等技术来限制模型复杂度,防止过拟合。
交叉验证:通过交叉验证来评估模型的泛化能力,确保模型不仅在训练数据上表现好,也能在未见数据上给出准确预测。
早停法:在训练过程中监控验证集的性能,一旦发现验证集上的性能不再提升,就停止训练,避免过拟合。
在这里插入图片描述

二、正则化方法

为了解决过拟合问题,通常采用正则化方法对模型进行约束。常见的正则化方法有L1正则化和L2正则化。

1. L1正则化

L1正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α ∑ j = 1 n ∣ θ j ∣ J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \alpha\sum_{j=1}^{n}|\theta_j| J(θ)=2m1i=1m(hθ(x(i))y(i))2+αj=1nθj
其中,第一项为损失函数,第二项为L1正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

2. L2正则化

L2正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α 2 ∑ j = 1 n θ j 2 J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \frac{\alpha}{2}\sum_{j=1}^{n}\theta_j^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2+2αj=1nθj2
其中,第一项为损失函数,第二项为L2正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

三、正则化参数的更新

在优化目标函数时,我们需要对正则化参数进行更新。以下为L2正则化的参数更新公式:
θ j : = θ j − α ( 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) x j ( i ) + λ θ j ) \theta_j := \theta_j - \alpha\left(\frac{1}{m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})x_j^{(i)} + \lambda\theta_j\right) θj:=θjα(m1i=1m(hθ(x(i))y(i))xj(i)+λθj)
其中, λ = α m \lambda = \frac{\alpha}{m} λ=mα为正则化参数。
在这里插入图片描述

四、Dropout

Dropout是一种有效的正则化方法,通过在训练过程中随机丢弃部分神经元,来减少模型对特定训练样本的依赖。以下是Dropout的实现步骤:
(1)在训练过程中,按照一定概率随机丢弃神经元;
(2)在测试过程中,将所有神经元的输出乘以概率因子。

五、代码实现

以下是基于PyTorch的拟合问题及优化代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.linear(x)
# 生成数据
x = torch.randn(100, 1)
y = 3 * x + 2 + torch.randn(100, 1)
# 实例化模型
model = LinearRegression(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)  # L2正则化
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 测试模型
model.eval()
with torch.no_grad():
    predicted = model(x).detach().numpy()
    print(f'预测值:{predicted}')

通过本文的介绍,相信大家对拟合问题及优化方法有了更深入的了解。在实际应用中,可根据数据特点选择合适的正则化方法,以提高模型的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1928477.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql(5.5)启动服务和环境配置

正常启动 参考:Javaweb基础之mysql回溯笔记(一) 总的来说就是在mysql的安装目录下,找到bin下面的msyqld.exe,双击即启动了mysql服务; 启动方式二 也可以直接找到windows的服务项进行启动,操作如下: 打开…

eclipse免安装版64位 2018版本

前言 eclipse是一个开放源代码的、基于Java的可扩展开发平台。就其本身而言,它只是一个框架和一组服务,用于通过插件组件构建开发环境。 一、下载地址 下载地址:http://source/download 选择如下图红色框文件内容下载 二、安装步骤 1、…

社交电商的新篇章:AI智能名片O2O商城小程序与传统微商的区别与融合

摘要 在数字经济蓬勃发展的今天,互联网技术的革新正以前所未有的速度重塑着商业格局。传统微商模式,尽管在初期借助社交媒体迅速崛起,但因其固有的局限性,如产品质量不一、营销手段单一、信任机制脆弱等,逐渐暴露出诸…

【实战场景】MongoDB迁移的那些事

【实战场景】MongoDB迁移的那些事 开篇词:干货篇【MongoDB迁移的方法】:1. 基于mongodump和mongorestore的迁移一、迁移前准备二、使用mongodump备份数据三、使用mongorestore还原数据四、注意事项 2. 基于MongoDB复制集的迁移一、迁移前准备二、配置新复…

Spring Boot整合Minio实现文件上传和读取

文章目录 一、简介1.分布式文件系统应用场景2.Minio介绍3.Minio优点 二、docker部署(windows系统)1.创建目录2.拉取镜像3.创建容器并运行4.访问控制台5.初始化配置 三、Spring Boot整合Minio1.创建demo项目2.引入依赖3.配置4.编写配置类5.MinIO工具类6.文…

ASP.NET Core----基础学习08----MVC中的属性路由

文章目录 1.MVC 中属性路由2.如果控制器名称与路由的第一级名称不一致3.指定读取的视图文件4.指定路由的一级 & 二级目录 1.MVC 中属性路由 step1: 在Startup.cs文件中设置仅使用UseMvc(不包含路由的设置) step2: 在控制器中…

实战案例:用百度千帆大模型API开发智能五子棋

前随着人工智能技术的迅猛发展,各种智能应用层出不穷。五子棋作为一款经典的棋类游戏,拥有广泛的爱好者。将人工智能技术与五子棋结合,不仅能提升游戏的趣味性和挑战性,还能展现AI在复杂决策问题上的强大能力。在本篇文章中&#…

如何使用 GPT?

​通过实例,来展示如何最好地使用 GPT。 生成文字 假设你在写一篇文章,需要在结尾加上这样一句:「California’s population is 53 times that of Alaska.」(加州的人口是阿拉斯加州的 53 倍)。 但现在你不知道这两个…

rancher单节点安装k8s

k3s 优点: 可用性 易于操作的轻量级部署模型 缺点: 与上游Kubernetes不同 RKE1 优点: 与上游Kubernetes紧密对齐 缺点: 严重依赖于 Docker RKE2 凭借 k3s 的优势和更紧密的上游协调,RKE2 将控制平面组件作为静态 pod 启动,由 kubelet 管理。 为了符合行业…

配置SMTP服务器的要点是什么?有哪些限制?

配置SMTP服务器安全性如何保障?如何高效配置服务器? SMTP作为电子邮件发送的核心协议,其配置对于确保邮件的成功传递和安全至关重要。AokSend将详细介绍配置SMTP服务器的关键要点,帮助读者建立一个高效、安全的邮件发送系统。 配…

LLM量化--AWQ论文阅读笔记

写在前面:近来大模型十分火爆,所以最近开启了一波对大模型推理优化论文的阅读,下面是自己的阅读笔记,里面对文章的理解并不全面,只将自己认为比较重要的部分摘了出来,详读的大家可以参看原文 原论文地址&am…

IIS只能访问根目录下的文件的解决方法

IIS只能访问根目录下的文件的解决方法 解决方法: 网站(右击) >> 高级设置 >>应用程序池 >> 选择(DefaultAppPool)

深入剖析 Android 开源库 EventBus 的源码详解

文章目录 前言一、EventBus 简介EventBus 三要素EventBus 线程模型 二、EventBus 使用1.添加依赖2.EventBus 基本使用2.1 定义事件类2.2 注册 EventBus2.3 EventBus 发起通知 三、EventBus 源码详解1.Subscribe 注解2.注册事件订阅方法2.1 EventBus 实例2.2 EventBus 注册2.2.1…

禹神:一小时快速上手Electron,前端Electron开发教程,笔记。一篇文章入门Electron

一、Electron是什么 简单的一句话,就是用htmlcssjsnodejs(Native Api)做兼容多个系统(Windows、Linux、Mac)的软件。 官网解释如下(有点像绕口令): Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面…

Qt实现MDI应用程序

本文记录Qt实现MDI应用程序的相关操作实现 目录 1.MDM模式下窗口的显示两种模式 1.1TabbedView 页签化显示 1.2 SubWindowView 子窗体显示 堆叠cascadeSubWindows 平铺tileSubWindows 2.MDM模式实现记录 2.1. 窗体继承自QMainWindow 2.2.增加组件MdiArea 2.3.定义统一…

react自定义校验报错问题修复 ProFormText

1、以下是tsx组件 自定义校验告警导致表单无法提交问题修复 修改如下:

Mac Dock栏多屏幕漂移固定的方式

记录一下 我目前的版本是 14.5 多个屏幕,Dock栏切换的方式: 把鼠标移动到屏幕的中间的下方区域,触到边边之后,继续往下移,就能把Dock栏固定到当前屏幕了。

flutter 手写 TabBar

前言: 这几天在使用 flutter TabBar 的时候 我们的设计给我提了一个需求: 如下 Tabbar 第一个元素 左对齐,试了下TabBar 的配置,无法实现这个需求,他的 配置是针对所有元素的。而且 这个 TabBar 下面的 滑块在移动的时…

全开源TikTok跨境商城源码/TikTok内嵌商城/前端uniapp+后端+搭建教程

多语言跨境电商外贸商城 TikTok内嵌商城,商家入驻一键铺货一键提货 全开源完美运营 海外版抖音TikTok商城系统源码,TikToK内嵌商城,跨境商城系统源码 接在tiktok里面的商城。tiktok内嵌,也可单独分开出来当独立站运营 二十一种…

论文翻译:通过云计算对联网多智能体系统进行预测控制

通过云计算对联网多智能体系统进行预测控制 文章目录 通过云计算对联网多智能体系统进行预测控制摘要前言通过云计算实现联网的多智能体控制系统网络化多智能体系统的云预测控制器设计云预测控制系统的稳定性和一致性分析例子结论 摘要 本文研究了基于云计算的网络化多智能体预…