pytorch深度学习基础 7 (简单的线性拟合+检验模型在验证集上的效果)

news2024/11/15 12:33:14

我们之前做的目的都是评估训练的损失,训练的损失Loss告诉我们,我们的模型是否能够完全拟合训练集,也就是说我们的模型是否有足够的能力处理数据中的相关信息。但是我们之前都是评价训练的好坏,并没有引入验证集。接下来我们就需要用训练集训练模型然后在验证集中检验模型的好坏,因为如果模型的参数在训练集中表现很好,但是如果应用与实际场景中,必然有很多训练集中没有的数据,可能在这些数据中表现比较差,那对于整个模型来说也不是特别好,所以我们的模型不仅要在训练集中表现良好,更要在验证集(val)中表现良好才行。

一、为了评估模型性能

  • 避免过拟合 模型在训练集上表现良好,但在验证集上表现不佳可能意味着过拟合。通过观察训练集和验证集的损失,能及时发现并调整模型,避免模型过度学习训练数据中的噪声和特定模式2。
  • 选择最佳超参数 利用验证集的损失来比较不同超参数配置下模型的性能,从而选择最优的超参数,提升模型的泛化能力2。

二、更好地理解模型学习过程

  • 观察模型收敛情况 比较训练集和验证集损失的变化趋势,了解模型在学习过程中的表现。例如,训练集损失持续降低但验证集损失趋于收敛,可能提示模型过拟合或训练集与测试集不是独立同分布2。

三、提高模型的泛化能力

  • 调整模型结构和参数 根据训练集和验证集损失的差异,对模型的结构和参数进行优化,使模型能够更好地适应新的数据,提高泛化能力2。

四、确保模型的可靠性

  • 验证模型的稳定性 观察不同随机种子下训练集和验证集的损失变化,评估模型对数据划分的敏感性,确保模型的稳定性和可靠性1。

获取数据集

那我们如何获取验证集呢,首先我们需要获得足够多的原始数据,这些数据往往是通过实际场景中收集的,但是小编这里直接写了一个程序来模拟数据,代码就分享给大家。

t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0,
                    8.0, 3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])


# 按照一定的规律生成100组原始数据集
# 假设的线性关系参数(这里我们先用简单的均值和方差来估计,但更准确的方法是使用最小二乘法)
# 注意:这里只是为了演示,实际中应该使用统计方法来准确估计
slope_est = (t_c.mean() - t_c.min()) / (t_u.mean() - t_u.min())
intercept_est = t_c.mean() - slope_est * t_u.mean()

# 生成额外的数据点
np.random.seed(0)  # 为了可重复性设置随机种子
additional_u = np.linspace(t_u.min(), t_u.max() * 2, 90)  # 生成额外的u值
additional_c = slope_est * torch.tensor(additional_u, dtype=torch.float32) + intercept_est + 0.5 * torch.randn(
    90)  # 加上一些噪声

# 合并原始数据和额外数据
t_u = torch.cat((t_u, torch.tensor(additional_u, dtype=torch.float32)))
t_c = torch.cat((t_c, additional_c))

 感兴趣的可以把他们打印出来看一下,因为我设置的随机的种子,所以大家打印出来的数据是完全一样的

tensor([ 35.7000,  55.9000,  58.2000,  81.9000,  56.3000,  48.9000,
         33.9000,  21.8000,  48.4000,  60.4000,  68.4000,  21.8000,
         23.3955,  24.9910,  26.5865,  28.1820,  29.7775,  31.3730,
         32.9685,  34.5640,  36.1595,  37.7551,  39.3506,  40.9461,
         42.5416,  44.1371,  45.7326,  47.3281,  48.9236,  50.5191,
         52.1146,  53.7101,  55.3056,  56.9011,  58.4966,  60.0921,
         61.6876,  63.2831,  64.8787,  66.4742,  68.0697,  69.6652,
         71.2607,  72.8562,  74.4517,  76.0472,  77.6427,  79.2382,
         80.8337,  82.4292,  84.0247,  85.6202,  87.2157,  88.8112,
         90.4067,  92.0023,  93.5978,  95.1933,  96.7888,  98.3843,
         99.9798, 101.5753, 103.1708, 104.7663, 106.3618, 107.9573,
        109.5528, 111.1483, 112.7438, 114.3393, 115.9348, 117.5303,
        119.1258, 120.7214, 122.3169, 123.9124, 125.5079, 127.1034,
        128.6989, 130.2944, 131.8899, 133.4854, 135.0809, 136.6764,
        138.2719, 139.8674, 141.4629, 143.0584, 144.6539, 146.2494,
        147.8449, 149.4404, 151.0360, 152.6315, 154.2270, 155.8225,
        157.4180, 159.0135, 160.6090, 162.2045, 163.8000])
tensor([ 0.5000, 14.0000, 15.0000, 28.0000, 11.0000,  8.0000,  3.0000,
        -4.0000,  6.0000, 13.0000, 21.0000, -4.2347, -3.2014, -2.1179,
        -1.6428, -0.9509, -0.5199,  0.2784,  1.8311,  2.9875,  2.8588,
         4.4212,  4.5838,  5.1310,  6.8234,  5.9734,  7.1035,  9.4547,
         8.6025, 10.6773,  9.9846, 11.9821, 11.9866, 13.3628, 13.4712,
        14.0208, 14.8172, 16.4681, 17.4892, 18.3710, 18.8813, 18.5837,
        19.6088, 20.5467, 21.7026, 21.6737, 21.9981, 24.0391, 24.4194,
        25.7778, 25.5416, 26.1872, 27.3508, 27.3104, 29.4915, 29.1853,
        31.0119, 31.5497, 32.4764, 33.0286, 32.5419, 34.0226, 35.3561,
        36.8098, 37.1465, 37.9528, 38.3130, 38.8218, 40.1028, 40.7696,
        42.1289, 41.1947, 42.5819, 43.5023, 43.5928, 45.2544, 45.7708,
        47.4863, 48.2580, 48.3122, 49.1236, 50.2917, 51.2436, 50.8630,
        52.8491, 52.9410, 53.7065, 54.7200, 54.7298, 56.8450, 57.1763,
        56.5910, 58.7009, 59.2632, 59.9197, 61.0314, 61.5569, 61.7117,
        62.8750, 63.9156, 65.0431])

生成训练集与验证集

然后我们先获取t_u的行数,通过对行数加权来获取我们想要的验证集的个数,小编这里加权给的是0.2,然后对原始数据进行打乱,把100个原始数据分为2份,一份是80个数据的训练集,另一份是20个数据的验证集。

n_samples = t_u.shape[0]  # 这条代码的作用是获取数组t_u的行数,并将这个值赋给变量n_samples
n_val = int(0.2 * n_samples)
shuffled_indices = torch.randperm(n_samples)  # 返回一个长度为 n_samples 的随机排列的整数序列
train_indices = shuffled_indices[:-n_val]  # 训练集
val_indices = shuffled_indices[-n_val:]    # 验证集

接下来的训练和之前几节的训练是一样的,区别就是,在训练的过程中顺便用训练训练集更新出来的params参数对验证集进行损失计算,并把Loss的值打印出来,方便后期的分析。 

import numpy as np
import torch
import torch.optim as optim # 导入优化器模块
torch.set_printoptions(edgeitems=2, linewidth=75) # 设置打印格式


def model(t_u, w, b):
    return w * t_u + b

def loss_fn(t_p, t_c):
    squared_diffs = (t_p - t_c)**2
    return squared_diffs.mean()
# 初始化参数



t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0,
                    8.0, 3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])


# 按照一定的规律生成100组原始数据集
# 假设的线性关系参数(这里我们先用简单的均值和方差来估计,但更准确的方法是使用最小二乘法)
# 注意:这里只是为了演示,实际中应该使用统计方法来准确估计
slope_est = (t_c.mean() - t_c.min()) / (t_u.mean() - t_u.min())
intercept_est = t_c.mean() - slope_est * t_u.mean()

# 生成额外的数据点
np.random.seed(0)  # 为了可重复性设置随机种子
additional_u = np.linspace(t_u.min(), t_u.max() * 2, 90)  # 生成额外的u值
additional_c = slope_est * torch.tensor(additional_u, dtype=torch.float32) + intercept_est + 0.5 * torch.randn(
    90)  # 加上一些噪声

# 合并原始数据和额外数据
t_u = torch.cat((t_u, torch.tensor(additional_u, dtype=torch.float32)))
t_c = torch.cat((t_c, additional_c))
print(t_u)
print(t_c)
t_un = 0.1 * t_u # 归一化处理,防止梯度爆炸
n_samples = t_u.shape[0]  # 这条代码的作用是获取数组t_u的行数,并将这个值赋给变量n_samples
n_val = int(0.2 * n_samples)

shuffled_indices = torch.randperm(n_samples)  # 返回一个长度为 n_samples 的随机排列的整数序列
print(shuffled_indices)
# tensor([ 6,  3,  2,  0,  7,  4,  5,  8,  9, 10,  1])

train_indices = shuffled_indices[:-n_val]  # 训练集
# print(train_indices)
val_indices = shuffled_indices[-n_val:]    # 验证集
# print(val_indices)


#  随机把原始数据t_u,t_c分成验证集和训练集
train_indices, val_indices
train_t_u = t_u[train_indices]
train_t_c = t_c[train_indices]

val_t_u = t_u[val_indices]
val_t_c = t_c[val_indices]

# 归一化处理
train_t_un = 0.1 * train_t_u
val_t_un = 0.1 * val_t_u


def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u,
                  train_t_c, val_t_c):
    for epoch in range(1, n_epochs + 1):
        train_t_p = model(train_t_u, *params)
        train_loss = loss_fn(train_t_p, train_t_c)

        val_t_p = model(val_t_u, *params)
        val_loss = loss_fn(val_t_p, val_t_c)

        optimizer.zero_grad()
        train_loss.backward()  # <2>
        optimizer.step()

        if epoch <= 3 or epoch % 500 == 0:
            print(f"Epoch {epoch}, Training loss {train_loss.item():.4f},"
                  f" Validation loss {val_loss.item():.4f}")

    return params

params = torch.tensor([1.0, 0.0], requires_grad=True)
learning_rate = 1e-3
optimizer = optim.SGD([params], lr=learning_rate)

training_loop(
    n_epochs = 3000,
    optimizer = optimizer,
    params = params,
    train_t_u = train_t_un,
    val_t_u = val_t_un,
    train_t_c = train_t_c,
    val_t_c = val_t_c)



我们把epochs改为13000轮

 不难发现,训练集和验证集收敛的还是比较快的,Loss值也比较豪,效果还是比较好的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2071696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java基础——自学习使用(多态)

一、多态的定义 父类的引用指向子类的对象。 B继承A&#xff0c;A abnew B();——父类引用指向子类的对象。 二、创建对象了解多态的内部结构 &#xff08;1&#xff09;父类即A类对象的内存结构图 &#xff08;2&#xff09;子类即B类对象的内存结构图 由于B中重写了父类A中…

EazyDraw for Mac 矢量图绘制设计软件

Mac分享吧 文章目录 效果一、下载软件二、开始安装1、双击运行软件&#xff0c;将其从左侧拖入右侧文件夹中&#xff0c;等待安装完毕2、应用程序显示软件图标&#xff0c;表示安装成功 三、运行测试安装完成&#xff01;&#xff01;&#xff01; 效果 一、下载软件 下载软件…

SSRF和CSRF实战复现

文章目录 SSRFWeb-Hacking-Lab-master1、Centos未授权访问2、Ubuntu未授权访问3、Ubuntu传入公钥访问4、ssrf_redis_lab_pickle_redis_lab CSRF:windphp SSRF SSRF(Server-Side Request Forgery:服务器端请求伪造) 是一种由攻击者构造形成由服务端发起请求的一个安全漏洞。 f…

第三课《排序》

前言 排序是将一组数据&#xff0c;按照指定的顺序或要求来进行排列的过程。是数据结构相关课程和内容较为重要和核心的内容之一&#xff0c;常常作为考试题和面试题目来考察学生和面试者&#xff0c;因此熟练掌握经典的排序算法原理和代码实现是非常重要的 本文介绍了几大较为…

AJAX(5)——Promise

Promise Promise对象用于表示一个异步操作的最终完成或失败及其结果值 语法&#xff1a; //创建Promise对象const p new Promise((resolve, reject) > {//执行异步代码setTimeout(() > {// resolve(成功结果)reject(new Error(失败结果))}, 2000)})//获取结果p.then(r…

坚鹏讲人才第13期:个人数字化转型——个人与时代的共赢之选

坚鹏讲人才第13期&#xff1a;个人数字化转型——个人与时代的共赢之选 在这个日新月异的时代&#xff0c;数字化转型已经成为当今时代的必然趋势&#xff0c;它不仅改变了我们的生活方式&#xff0c;也正在改变着各行各业的运营模式。数字化时代&#xff0c;不仅需要数字化企…

网络udp及ipc内存共享

大字符串找小字符串 调试 1. 信号处理函数注册&#xff1a;•一旦使用 signal 函数注册了信号处理函数&#xff0c;该函数就会一直有效&#xff0c;直到程序结束或者显式地取消注册。2. 注册多次的影响&#xff1a;•如果多次注册同一信号的处理函数&#xff0c;最后一次注册的…

快9月了刚结束基础,武忠祥强化vs张宇18讲应该如何选择?

快9月了&#xff0c;最近有一部分同学刚结束基础&#xff0c;在后台提问&#xff1a;强化到底该学武忠祥还是张宇18讲&#xff1f;其实这个问题&#xff0c;如果你是6月份开始强化&#xff0c;很好回答&#xff0c;但是现在已经快9月份了&#xff0c;很多同学都开始做真题了&am…

代码随想录 刷题记录-16 贪心算法(1)贪心理论基础及习题

一、理论基础 什么是贪心 贪心的本质是选择每一阶段的局部最优&#xff0c;从而达到全局最优。 贪心的套路&#xff08;什么时候用贪心&#xff09; 贪心算法并没有固定的套路。 所以唯一的难点就是如何通过局部最优&#xff0c;推出整体最优。 靠自己手动模拟&#xff0c…

深度学习 回归问题

1. 梯度下降算法 深度学习中, 梯度下降算法是是一种很重要的算法. 梯度下降算法与求极值的方法非常类似, 其核心思想是求解 x ′ x x′, 使得 x ′ x x′ 在取 x ⋆ x^{\star} x⋆ 时, 可以使得 l o s s 函数 loss函数 loss函数 的值最小. 其中, 在求解 x ′ x x′ 的过…

罗德与施瓦茨RS、UPV 音频分析仪 250KHZ 双通道分析仪UPL

罗德与施瓦茨 UPV 音频分析仪的规格包括&#xff1a; 模拟 双通道分析仪&#xff1a;带宽高达 250 kHz 生成正弦波信号&#xff1a;单通道最高 185 kHz&#xff08;需要 B1&#xff09;和双通道最高 80 kHz FFT本底噪声&#xff1a;< -140dB 固有频率响应&#xff08;20 …

链动 2+1 模式小程序 AI 智能名片商城源码培训邀约策略研究

摘要&#xff1a;本文深入剖析链动 21 模式小程序 AI 智能名片商城源码的培训邀约策略&#xff0c;从该源码的价值出发&#xff0c;阐述邀约的重要性&#xff0c;并详细介绍具体的邀约策略&#xff0c;旨在为相关培训活动提供切实可行的指导&#xff0c;提高邀约成功率&#xf…

前端如何快速切换node版本:nvm

安装之前最好卸载计算机已经安装的node&#xff08;通过Windows菜单找到Node.js的卸载程序&#xff0c;运行卸载程序&#xff09;。下载nvm安装包&#xff1a;nvm安装地址。安装nvm&#xff0c;选择nvm安装根路径指定nodejs的安装路径打开命令行&#xff0c;输入nvm -v 可查看版…

Object.create的原型继承

● 首先我们来从这种方法来创建一个和之前一样计算年龄的方法 const PersonProto {cacleAge() {console.log(2038 - birthYear);} };const zhangsan Object.create(PersonProto); console.log(zhangsan);● 发现确实可以实现原型继承的特性 const PersonProto {cacleAge()…

odoo17 group col 属性

odoo17 group col 属性 以前版本&#xff0c;col4,在17中不能用了&#xff0c;或者方法变了 <record id"hetong.addfj_wizard" model"ir.ui.view"><field name"name">合同附件</field><field name"model">het…

免费的大模型插件llm.nvim

llm.nvim&#xff08;https://github.com/StubbornVegeta/llm.nvim&#xff09;是一款基于cloudflare的免费大模型插件&#xff0c;你可以像使用ChatGPT一样和它进行对话 在使用这款插件之前&#xff0c;你需要注册cloudflare&#xff0c;获取你的account和API key。你可以在这…

RCE - - 无字母数字远程命令执行

题目源码 <?php if(isset($_GET[code])){$code $_GET[code];if(strlen($code)>35){die("Long.");}if(preg_match("/[A-Za-z0-9_$]/",$code)){die("NO.");}eval($code); }else{highlight_file(__FILE__); } 分析 这道题 code 接 get 传…

【Qt】常用控件QProgressBar

常用控件QProgressBar 使用QProgressBar表示一个进度条&#xff01;&#xff01;&#xff01; QProgressBar的核心属性 属性说明 minimum 进度条最⼩值 maximum 进度条最⼤值 value 进度条当前值 alignment ⽂本在进度条中的对⻬⽅式. Qt::AlignLeft : 左对⻬Qt::Alig…

AJAX(4)——XMLHttpRequest

XMLHttpRequest 定义&#xff1a;XMLHttpRequest(XHR)对象用于与服务器交互。通过XMLHttpRequest可以在不刷新页面的情况下请求特定URL&#xff0c;获取数据。这允许网页在不影响用于操作的情况下&#xff0c;更新页面的局部内容。XMLHttpRequest在AJAX编程中被大量使用 关系…

第6章 B+树索引

目录 6.1 没有索引的查找 6.1.1 在一个页中的查找 6.1.2 在很多页中查找 6.2 索引 6.2.1 一个简单的索引方案 6.2.2 InnoDB中的索引方案 6.2.2.1 聚簇索引 6.2.2.2 二级索引 6.2.2.3 联合索引 6.2.3 InnoDB的B树索引的注意事项 6.2.3.1 根页面万年不动窝 6.2.3.2 内节…