李沐深度学习记录4:11模型选择、欠拟合和过拟合

news2025/1/12 19:05:12

权重衰减从零开始实现

#高维线性回归
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

#整个流程是,1.生成标准数据集,包括训练数据和测试数据
#          2.定义线性模型训练
#           模型初始化(函数)、包含惩罚项的损失(函数)
#           定义epochs进行训练,每训练5轮评估一次模型在训练集和测试集的损失,画图显示
#           训练结束后分别查看并比较是否添加范数惩罚项损失对应的训练结果w的L2范数
#生成数据集
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5  #训练数据样本数20,测试样本数100,数据维度200,批量大小5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05  #生成w矩阵(200,1),w值0.01,偏置b为0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) #生成训练数据集X(20,200),y(20,1),y=Xw+b+噪声,train_data接收返回的X,y
train_iter = d2l.load_array(train_data, batch_size)  #传入数据集和批量大小,构造训练数据迭代器
test_data = d2l.synthetic_data(true_w, true_b, n_test) #生成测试数据集
test_iter = d2l.load_array(test_data, batch_size, is_train=False)  #构造测试数据迭代器

#初始化模型参数
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

#定义L2范数惩罚项
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2  #L2范数公式需要开平方根,但这里L2范数惩罚项是L2范数的平方,所以不需要开平方根了

#训练代码
def train(lambd):  #输入λ超参数
    w, b = init_params()  #初始化模型参数
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss  #net线性模型torch.matmul(X, w) + b;loss是均方误差
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):  #进行多次迭代训练
        for X, y in train_iter:  #每个epoch,取训练数据
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)  #loss计算加上了λ×范数惩罚项
            l.sum().backward()  #这里计算损失和,下面参数更新时会对梯度求平均再更新参数
            d2l.sgd([w, b], lr, batch_size)  #进行参数更新操作
        if (epoch + 1) % 5 == 0:  #每5次epoch训练,评估一次模型的训练损失和测试损失
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())  #训练结束后,计算w的L2范数(没有平方)

#λ为0,无正则化项,训练
train(lambd=0)
d2l.plt.show()

在这里插入图片描述

#λ为10,有正则化项,训练
train(lambd=5)
d2l.plt.show()

在这里插入图片描述

权重衰减的简洁实现

#权重衰减的简洁实现
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))   #定义模型
    for param in net.parameters():   #初始化参数
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')  #计算loss,这里不包含正则项
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    #在参数优化部分,计算梯度时加入了权重衰减
    #所以是计算loss时没计算正则项,只是在计算梯度时加入了权重衰减吗?
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):   #训练100轮
        for X, y in train_iter:  #对于每轮,取数据训练
            trainer.zero_grad()   #梯度清零
            l = loss(net(X), y)  #计算loss
            l.mean().backward() #反向传播
            trainer.step()  #更新梯度
        if (epoch + 1) % 5 == 0:   #每5轮评估一次模型在测试集和训练集的损失
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())
#没有进行权重衰减
train_concise(0)

在这里插入图片描述

#进行权重衰减
train_concise(5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1061952.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一种4g扫码付费通电控制器方案

之前开发了一款扫码付款通电控制器 功能:用户扫码付款后设备通电,开始倒计时,倒计时结束后设备断电,资金到账商家的商家助手里面,腾讯会收取千分之6手续费。 产品主要应用场景 本产品主要应用于各类无人值守或者自助…

【算法基础】基础算法(二)--(高精度、前缀和与差分)

一、高精度 当一个数很大,大到 int 无法存下时,我们可以考虑用数组来进行存储,即数组中一个位置存放一位数。 但是对于数组而言,一个数顺序存入数组后,对其相加减是很简单的。但是当需要进位时,还是很麻烦的…

(c语言)调试——习题

第一题 题目: 解析: 答案:C 栈溢出属于运行时错误,在错误分类时不能分作一类 第二题 题目 : 解析: 答案:A F5是调试不执行,ctrlF5是开始执行不调试 第三题 题目: …

做好微信CRM,这些功能你不可不知!

在当前的数字化时代,微信已成为我们日常生活中的重要元素,无论是社交交流、信息传递还是商务合作,微信都扮演着不可或缺的角色。为了更有效地管理微信资源并提高工作效率,很多组织和公司都选择引入微信CRM系统。那么,怎…

服装服饰小程序商城的作用是什么

服装绝对算是市场重要的组成部分,零售批发都有大量从业者,随着线下流量匮乏、经营困难重重,很多厂家商家选择线上经营,主要方式是直播、入驻第三方平台等,同时私域节奏加快及线上平台限制等,不少商家也是通…

Java架构师设计思想

目录 1 设计核心思想封装2 设计核心思想隔离2.1 隔离的好处3 设计思想由大到小,由粗到精,逐步细化3.1 由大到小3.2 由粗到精3.3 逐步细化4 设计思想迭代4.1 和设计思想由大到小,由粗到精,逐步细化区别5 总结1 设计核心思想封装 首先我们来看一下什么是封装。那封装呢也叫做…

NodeMCU ESP8266硬件开发板的熟悉

文章目录 硬件开发环境的熟悉基础介绍什么是 ESP8266 NodeMCU?NodeMCU芯片ESP12-E 模组开发板 ESP8266 版本引脚图Power GND I2CGPIOADCUARTSPIPWMControl 总结 硬件开发环境的熟悉 基础介绍 什么是 ESP8266 NodeMCU? ESP8266是乐鑫开发的一款低成本 …

Linux系统下xxx is not in the sudoers file解决方法

文章目录 遇到问题解决方法参考 遇到问题 服务器上新建用户,名为lishizheng,现在想给该用户添加sudo权限。 $ sudo lsof -i tcp:7890 [sudo] password for lishizheng: lishizheng is not in the sudoers file. This incident will be reported.解决…

算法:强连通分量(SCC) Tarjan算法

强连通分量&#xff0c;不能再加任何一个点了&#xff0c;再加一个点就不是强连通了 vector<int>e[N]; int dfn[N],low[N],tot; bool instk[N]; int scc[N],siz[N],cnt; void tarjan(int x){//入x时,盖戳,入栈dfn[x]low[x]tot;q.push(x);instk[x]true;for(auto y:e[x]){i…

mysql8.0.31 源码阅读

知识背景 问&#xff1a;说说构造函数有哪几种&#xff1f;分别有什么用&#xff1f; C中的构造函数可以分为4类&#xff1a;默认构造函数、初始化构造函数、拷贝构造函数、移动构造函数。 1. 默认构造函数和初始化构造函数&#xff08;在定义类的对象时&#xff0c;完成对象…

Redis-持久化机制

持久化机制介绍 RDBAOFRDB和AOF对比 RDB rdb的话是利用了写时复制技术&#xff0c;他是看时间间隔内key值的变化量&#xff0c;就比如20秒内如果有5个key改变过的话他就会创建一个fork子进程&#xff08;bgsave&#xff09;&#xff0c;通过这个子进程&#xff0c;将数据快照进…

QT商业播放器

QT商业播放器 总体架构图 架构优点&#xff1a;解耦&#xff0c;采用生产者消费者设计模式&#xff0c;各个线程各司其职&#xff0c;通过消息队列高效协作 这个项目是一个基于ijkplayer和ffplayer.c的QT商业播放器, 项目有5部分构成&#xff1a; 前端QT用户界面 后端是集成了…

视频二维码的制作方法,支持内容修改编辑

现在学生经常会需要使用音视频二维码&#xff0c;比如外出打开、才艺展示、课文背诵等等。那么如何制作一个可以长期使用的二维码呢&#xff1f;下面来给大家分享一个二维码制作&#xff08;免费在线二维码生成器-二维码在线制作-音视频二维码在线生成工具-机智熊二维码&#x…

快速了解Spring Cache

SpringCache是一个框架&#xff0c;实现了基于注解的缓存功能&#xff0c;只需要简单的加一个注解&#xff0c;就可以实现缓存功能。 SpringCache提供了一层抽象&#xff0c;底层可以切换不同的缓存实现。例如&#xff1a; EHChche Redis Caffeine 常用注解&#xff1a; Enabl…

JMETER自适应高分辨率的显示器

系列文章目录 历史文章 每天15分钟JMeter入门篇&#xff08;一&#xff09;&#xff1a;Hello JMeter 每天15分钟JMeter入门篇&#xff08;二&#xff09;&#xff1a;使用JMeter实现并发测试 每天15分钟JMeter入门篇&#xff08;三&#xff09;&#xff1a;认识JMeter的逻辑控…

UG\NX二次开发 获取所有子部件,封装两个函数

文章作者:里海 来源网站:《里海NX二次开发3000例专栏》 感谢粉丝订阅 感谢 凉夜ronin 订阅本专栏,非常感谢。 简介 UG\NX二次开发 获取所有子部件,封装两个函数 效果 获取非抑制的所有子部件 //获取非抑制的所有子部件 vector<tag_t> GetChildPart(tag_t partOcc) {…

MyBatisPlus(十)判空查询

说明 判空查询&#xff0c;对应SQL语句中的 IS NULL语句&#xff0c;查询对应字段为 NULL 的数据。 isNull /*** 查询用户列表&#xff0c; 查询条件&#xff1a;电子邮箱为 null 。*/Testvoid isNull() {LambdaQueryWrapper<User> wrapper new LambdaQueryWrapper<…

基于安卓android微信小程序的远景民宿预订小程序

运行环境 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Ma…

W25Q128芯片手册精读

文章目录 前言1. 概述2. 特性3. 封装类型和引脚配置3.1 8焊盘WSON 8x6 mm3.2其他封装 4. 引脚描述4.1 片选4.2 串行数据输入输出4.3 写保护4.4 保持脚4.5 时钟 5. 块图6. 功能描述6.1 SPI功能6.1.1 标准SPI6.1.2 双通道SPI6.1.3 四通道SPI6.1.4 保持功能 6.2 写保护6.2.1 写保护…

【Golang】gin框架入门

文章目录 gin框架入门认识gingo流行的web框架gin介绍快速入门 路由RESTful API规范请求方法URI处理函数分组路由 请求参数GET请求参数POST请求参数路径参数文件参数 响应字符串方式JSON方式XML方式文件格式设置HTTP响应头重定向YAML方式 模板渲染基本使用多个模板渲染自定义模板…