深度学习_15_过拟合欠拟合

news2024/11/16 21:40:19

过拟合和欠拟合
在这里插入图片描述
过拟合和欠拟合是训练模型中常会发生的事,如所要识别手势过于复杂,如五角星手势,那就需要更改高级更复杂的模型去训练,若用比较简单模型去训练,就会导致模型未能抓住手势的全部特征,那简单模型估计只能抓住五角星的其中一个角做特征,那么这个简单模型很可能就会将三角形与五角星混淆,这就是所谓欠拟合

若用识别五角星的复杂模型去识别三角形也是不行的,模型会过拟合,即学习了过多不重要的部分,可能会把三角形每条边所画的时间也当作学习的内容,即便我们人知道什么时候画哪条边都无所谓。

过拟合和欠拟合的表现都是模型的识别精度不够,所以要想判断模型是过拟合还是欠拟合,除了理论还是要多调试

如:
在这里插入图片描述
合适的模型应该是抛物线,上述左边是欠拟合,右边是过拟合

在这里插入图片描述
训练集和测试集

值得注意的是训练集和测试集必须是分开的,训练模型用训练集,一定不能让测试集污染模型

模型过拟的特征即对见过的数据集表现非常好,而对从未见过的模型表现非常差,若不把训练,测试集完全分开,最后的模型过拟合将无法被发现

实例:

完整代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)


# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        train_loss, train_acc = d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        test_loss = evaluate_loss(net, test_iter, loss)

        # 将当前的损失值添加到列表中
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}")
    print(net[0].weight)
    # 假设 train_losses 和 test_losses 是已经计算出的损失值列表
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

分部讲解如下:

问题实例:
在这里插入图片描述

产生高斯分布随机数x并按上述式子生成训练集和验证集y,并对生成的y再添加一些杂音处理
注意:训练集一定要打乱,不要排序,排序会让训练效果大打折扣,如果训练数据是按照某种特定顺序排列的,那么模型可能会学习到这种顺序并在这个过程中引入偏差,导致模型在未见过的新数据上的泛化能力下降,打乱训练集的目的通常是为了防止模型学习到训练数据中的任何顺序依赖性,这样可以提高模型在随机或未见过的新数据上的泛化能力。

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)

计算损失函数,并不会更新迭代模型,所以用他来测试模型测试集损失

def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

训练函数,将X和对应y放在一起,即是进行模型迭代更新,又能计算模型训练损失,测试损失并绘制相应图形

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()  # 默认取平均损失
    input_shape = train_features.shape[-1]  # 模型大小取train_features最后一项大小
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])  # 整体数据集分成<= 10批次

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)  # 梯度下降算法

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        train_loss, train_acc = d2l.train_epoch_ch3(net, train_iter, loss, trainer)  # 训练迭代模型
        test_loss = evaluate_loss(net, test_iter, loss)

        # 将当前的损失值添加到列表中
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}")
    print(net[0].weight)  # 输出训练好的模型
    # 假设 train_losses 和 test_losses 是已经计算出的损失值列表
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # 设置y轴的范围从0.01到100
    plt.show()

主函数

# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

利用上述实例验证欠拟合和过拟合以及正常拟合

在这里插入图片描述
上述函数对应真正的模型为:

true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

当然还有一些杂质,可忽略

那么可知预训练模型取四个维度就能做到正常拟合,而取二十个维度就是过拟合,取四个以下维度就是欠拟合

过拟合即取二十维度效果:

在这里插入图片描述
可以看出损失在下降到最低点的时候还会有上升
在这里插入图片描述
这是因为学完主要四个维度后又将本应是0的维度也学习了,也就是学习了无用的杂质。
在这里插入图片描述

欠拟合二维度模型效果:

在这里插入图片描述
损失很大,这也是没办法,毕竟还有很多重要维度没有学习上,本质上是模型过小

正常拟合四维度模型效果:

在这里插入图片描述
正常拟合的模型在损失到达最低点后便不再上升,训练出来的模型与真实数据也及其接近

在这里插入图片描述
正常拟合才是我们训练模型的期望状态

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1479680.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【OCR识别】使用OCR技术还原加密字体文字

文章目录 1. 写在前面2. 页面分析3. 字符知识4. 加密分析 【作者主页】&#xff1a;吴秋霖 【作者介绍】&#xff1a;Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作&#xff01; 【作者推荐】&#xff1a;对JS逆向感兴趣的朋…

The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits

The Era of 1-bit LLMs: All Large Language Models Are in 1.58 Bits 相关链接&#xff1a;arxiv、github 关键字&#xff1a;1-bit LLMs、BitNet、模型压缩、能耗效率、模型性能 摘要 近期的研究&#xff0c;例如BitNet&#xff0c;正在为1-bit大型语言模型&#xff08;LLMs…

网站文章被百度快速收录的工具

百度是中国最主要的搜索引擎之一&#xff0c;对于网站管理员来说&#xff0c;网站文章被百度快速收录是至关重要的&#xff0c;因为这直接影响着文章的曝光和网站的流量。然而&#xff0c;许多网站管理员都会问一个常见的问题&#xff1a;文章百度收录需要几天&#xff1f;在这…

ISP代理是什么?怎么用?

在跨境出海业务中&#xff0c;代理IP对于您的在线任务至关重要&#xff0c;尤其是对于那些运行多个帐户的人来说。为您的帐户选择正确类型的代理对于确保帐户安全非常重要&#xff0c;劣质的IP容易使账号遭受封号风险。IPFoxy的多种代理IP类型应用范围各有侧重&#xff0c;其中…

项目实现json字段

有些很复杂的信息&#xff0c;我们一般会用扩展字段传一个json串&#xff0c;字段一般用text类型存在数据库。mysql5.7以后支持json类型的字段&#xff0c;还可以进行sql查询与修改json内的某个字段的能力。 1.json字段定义 ip_info json DEFAULT NULL COMMENT ip信息, 2.按…

C语言学习笔记(二)

C语言学习 学习笔记(一) 学习笔记(二&#xff09; 文章目录 C语言学习一、C语言中的数据类型进制二进制八进制十六进制进制转换表 单位换算寻址 数据类型基本类型整数类型整数的有符号和无符号实数类型字符型 构造类型指针类型空类型总结 常量直接常量符号常量转义符 符号常量…

Java配置49-nginx 反向代理 sftp 服务器

1. 背景 后端服务需要通过部署在跳板机上的 nginx 访问一个外网的 SFTP 服务器。 2. 方法 nginx从 1.9.0 开始&#xff0c;新增加了一个stream模块&#xff0c;用来实现四层协议的转发、代理或者负载均衡等。 首先检查 nginx 版本信息及是否安装了 stream 模块。 进入 ngi…

java程序员面试笔试宝典答案,java面试框架问题

目录 由于文档内容过多&#xff0c;共计有500页&#xff0c;因此为了避免影响到大家的阅读体验&#xff0c;在此只以截图展示部分内容&#xff0c;详细完整版的可以在文末获取&#xff01; 部分内容展示 深入浅出索引&#xff08;上&#xff09; 索引的常见模型InnoDB 的索引…

回溯 Leetcode 332 重新安排行程

重新安排行程 Leetcode 332 学习记录自代码随想录 给你一份航线列表 tickets &#xff0c;其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票都属于一个从 JFK&#xff08;肯尼迪国际机场&#xff09;出发的先生&a…

使用R语言进行主成分和因子分析

一、数据描述 数据来源2013年各地区水泥制造业规模以上企业的各主要经济指标&#xff0c;原始数据来源于2014年&#xff08;《中国水泥统计年鉴》&#xff09;&#xff0c;试对用主成分和因子进行经济效益评价。 地区,企业个数&#xff08;亿元&#xff09;,流动资产合计&…

亚信安慧AntDB之国密算法介绍

近年来&#xff0c;为摆脱对国外技术和产品的过度依赖&#xff0c;建设行业网络安全环境&#xff0c;增强我国行业信息系统安全、可靠的能力&#xff0c;国家有关机关和监管机构站在国家安全和长远战略的高度提出了“推动国密算法应用实施、加强行业安全可控”的要求。 密码算…

感谢信∣企企通再获肯定,中国煤科【天玛智控】SRM项目成功上线,推动煤矿供应链智能化高效协同发展

近日&#xff0c;煤矿智能无人化开采技术引领者【北京天玛智控科技股份有限公司】&#xff08;以下简称“天玛智控”&#xff09;携手企企通打造的SRM数字化采购平台成功上线。系统上线后&#xff0c;实现了天玛智控与供应商之间的信息共享和业务协作&#xff0c;提升采购业务效…

解决android studio build Output中文乱码

1.效果如下所示&#xff1a; 代码运行报错的时候&#xff0c;Build Output报的错误日志中中文部分出现乱码&#xff0c;导致看不到到底报的什么错。 2.解决办法如下&#xff1a; 点击Android studio开发工具栏的Help-Edit Custom VM Options....&#xff0c;Android studio会…

Node.js中的并发和多线程处理

在Node.js中&#xff0c;处理并发和多线程是一个非常重要的话题。由于Node.js是单线程的&#xff0c;这意味着它在任何给定时间内只能执行一个任务。然而&#xff0c;Node.js的事件驱动和非阻塞I/O模型使得处理并发和多线程变得更加高效和简单。在本文中&#xff0c;我们将探讨…

继承-学习2

this关键字&#xff1a;指向调用该方法的对象&#xff0c;一般我们是在当前类中使用this关键字&#xff0c;所以我们常说代表本类对象的引用 super关键字&#xff1a;代表父类存储空间的标识(可看作父类对象的引用) 父类&#xff1a; package ven;public class Fu {//父类成员…

Jenkins笔记(一)

个人学习笔记&#xff08;整理不易&#xff0c;有帮助点个赞&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_weixin_42717928的博客-CSDN博客 目录 一&#xff1a;简单了解 二&#xff1a;什么是DevOps 三&#xff1a;安装Jenkins 四&#xff1…

uniapp实现-审批流程效果

一、实现思路 需要要定义一个变量, 记录当前激活的步骤。通过数组的长度来循环数据&#xff0c;如果有就采用3元一次进行选择。 把循环里面的变量【name、status、time】, 全部替换为取出的那一项的值。然后继续下一次循环。 虚拟的数据都是请求来的, 组装为好渲染的格式。 二…

Diffusion Models/Score-based Generative Models背后的深度学习原理(5):伪似然和蒙特卡洛近似配分函数

Diffusion Models专栏文章汇总&#xff1a;入门与实战 前言&#xff1a;有不少订阅我专栏的读者问diffusion models很深奥读不懂&#xff0c;需要先看一些什么知识打下基础&#xff1f;虽然diffusion models是一个非常前沿的工作&#xff0c;但肯定不是凭空产生的&#xff0c;背…

FaceBook获取广告数据

1、访问 广告管理工具 确认自己登陆的账号下面能看到户。 ​ 2、使用 图谱Api探索工具 生成用户短期口令 ​ 3、get请求(或者浏览器直接打开)访问&#xff1a; https://graph.facebook.com/v19.0/me?fieldsid,name, email&access_token{上一步生成的口令} ​ 4、短期…

ChatGPT4.0 的优势、升级 4.0 为什么这么难以及如何进行升级?

前言 “ChatGPT4.0一个月多少人民币&#xff1f;” ”chatgpt4账号“ ”chatgpt4 价格“ “chatgpt4多少钱” 最近发现很多小伙伴很想知道关于ChatGPT4.0的事情&#xff0c;于是写了这篇帖子&#xff0c;帮大家分析一下。 一、ChatGPT4.0 的优势 &#xff08;PS&#xff1a;…