动手学深度学习(Pytorch版)代码实践 -深度学习基础-10权重衰减

news2025/3/13 11:16:23

10权重衰减

"""
正则化是处理过拟合的常用方法:在训练集的损失函数中加入惩罚项,以降低学习到的模型的复杂度。
保持模型简单的一个特别的选择是使用L2惩罚的权重衰减。这会导致学习算法更新步骤中的权重衰减。
"""

import torch
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp

n_train, n_test, num_input, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_input,1)) * 0.01, 0.05

train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)

test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)


#初始化模型参数
def init_params():
    w = torch.normal(0,1,size=(num_input,1), requires_grad=True)
    b = torch.zeros(1,requires_grad=True)
    return [w,b]

#定义L2范数惩罚
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

def l1_penalty(w):
    return torch.sum(torch.abs(w))

# 定义模型
def linreg(X, w, b):
    """线性回归模型"""
    return torch.matmul(X, w) + b

# 定义损失函数
def squared_loss(y_hat, y):
    """均方损失函数"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

# 定义优化函数
def sgd(params, lr, batch_size):
    """小批量随机梯度下降"""
    # 更新参数时不需要计算梯度
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size  # 参数更新
            param.grad.zero_()  # 梯度清零

#定义训练代码实现
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: linreg(X, w, b), squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())


#忽略正则化直接训练¶
# train(lambd=0)
#w的L2范数是: 14.630496978759766

# 使用权重衰减
# train(lambd=3)
# d2l.plt.show() 


#权重衰减-简洁实现
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_input, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = lp.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

train_concise(0)
d2l.plt.show() 
# w的L2范数是: 0.33992505073547363

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1850875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

揭秘古代手术工具与技术:从中国起源的医疗奇迹

在人类历史的长河中,医学的发展一直是推动社会进步的重要力量。而手术作为医学的一个重要分支,其发展历程同样充满了传奇色彩。今天,我们将带您走进古代手术的世界,揭秘那些令人惊叹的手术工具和技术。 这把手术刀出土于河北西村遗…

sqlmap使用以及GUI安装

下载 GUI版地址: GitHub - honmashironeko/sqlmap-gui: 基于官版本 SQLMAP 进行人工汉化,并提供GUI界面及多个自动化脚本 GUI使用 可以点击.bat启动 如果点击.bat启动不了就在这里打开cmd,输入对应的.bat来启动 linux安装 地址:sqlmap: automatic SQL injection…

express+vue在线im实现【三】

往期内容 expressvue在线im实现【一】 expressvue在线im实现【二】 本期示例 本期总结 支持各种类型的文件上传,常见文件类型图片,音频,视频等,上传时同步获取音频与视频的时长,以及使用上传文件的缓存路径来作为vi…

天马学航——智慧教务系统(移动端)开发日志六

天马学航——智慧教务系统(移动端)开发日志六 日志摘要:统一身份认证设计,修复了选课信息错乱的问题 界面设计 实现思路 使用 Java 和 Jedis 完成实现: 步骤一:添加 Jedis 依赖 首先需要在项目中添加 Jedis 依赖,…

已解决VirtualMachineError: 虚拟机错误的正确解决方法,亲测有效!!!

已解决VirtualMachineError: 虚拟机错误的正确解决方法,亲测有效!!! 目录 问题分析 报错原因 解决思路 解决方法 分析错误日志 优化代码 内存泄漏排查 优化递归调用 调整JVM参数 使用监控工具 增加物理内存或升级硬件…

芝麻清单助力提升学习工作效率 专注时间完成有效的待办事项

芝麻清单助力提升学习&工作效率 专注时间完成有效的工作。今天我们给大家带来一个专注清单,一个更高效的学习和工作的方法! 我们都知道,专注做一个事情,会有效的提升效率,让事情更高效的完成。如果是学习的话&…

java基于ssm+jsp 母婴用品网站

1管理员功能模块 管理员登录,管理员通过输入用户名、密码等信息进行系统登录,如图1所示。 图1管理员登录界面图 管理员登录进入母婴用品网站可以查看主页、个人中心、用户管理、商品分类管理、商品信息管理、留言板管理、成长交流、系统管理、订单管理、…

Springboot应用的信创适配-补充

Springboot应用的信创适配-CSDN博客 因为篇幅限制,这里补全Spring信创适配、数据库信创适配、Redis信创适配、消息队列信创适配等四个章节。 Springboot应用的信创适配 Springboot应用的信创适配,如上图所示需要适配的很多,从硬件、操作系统、…

vue3 computed与watch,watchEffect比较

相同点 都是要根据一个或多个响应式数据进行监听 不同点 computed 如要return回来一个新的响应式值,且这个值不允许直接修改,想要修改的话可以设置set函数,在函数里面去修改所依赖的响应式数据,然后计算属性值会基于其响应式依…

多功能投票系统(ThinkPHP+FastAdmin+Uniapp)

让决策更高效,更民主🌟 ​基于ThinkPHPFastAdminUniapp开发的多功能系统,支持图文投票、自定义选手报名内容、自定义主题色、礼物功能(高级授权)、弹幕功能(高级授权)、会员发布、支持数据库私有化部署,Uniapp提供全部无加密源码…

ffmpeg音视频开发从入门到精通——ffmpeg实现音频抽取

文章目录 FFmpeg 实现音频流抽取1. 包含FFmpeg头文件与命名空间声明2. 主函数与参数处理3. 打开输入文件4. 获取文件信息5. 查找音频流6. 分配输出文件上下文7. 猜测输出文件格式8. 创建新的音频流9. 打开输出文件10. 写入文件头信息11. 读取并写入音频数据12. 写入文件尾部信息…

vue中的状态管理

第1部分:引言 状态管理是应用中数据流动和变更的核心机制。在Vue应用中,状态管理不仅涉及到组件间的数据共享,还包括了数据的持久化、异步操作的处理等复杂场景。良好的状态管理策略可以提高应用的响应速度,降低组件间的耦合度&a…

经典游戏案例:植物大战僵尸

学习目标:植物大战僵尸核心玩法实现 游戏画面 项目结构目录 部分核心代码 using System; using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.SceneManagement; using Random UnityEngine.Random;public enum Z…

(2024)豆瓣电影详情页内容爬虫详解和源码

&#xff08;2024&#xff09;豆瓣电影详情页内容爬虫详解和源码 这是一个Python爬虫程序&#xff0c;用于抓取豆瓣电影详情页面如https://movie.douban.com/subject/1291560/的数据。它首先发送GET请求&#xff0c;使用PyQuery解析DOM&#xff0c;然后根据<br>标签分割H…

C语言第17篇:预处理详解

1、预定义符号 C语言设置了一些预定义符号&#xff0c;可以直接使用。预定义符号也是在预处理期间处理的。 __FILE__ //进行编译的源文件 __LINE__ //文件当前的行号 __DATE__ //文件被编译的日期 __TIME__ //文件被编译的时间 __STDC__ //如果编译器遵循ANSI…

LaTeX中添加矩阵分块虚线并设置虚线疏密

对于大型矩阵&#xff0c;有时需要添加分块虚线。 方法为使用arydshln宏包&#xff0c;然后在array环境中设置虚线。需要注意的是&#xff0c;使用矩阵环境需要搭配amsmath宏包使用&#xff0c;且需放在amsmath宏包之后。即导言区设置为 \usepackage{amsmath} \usepackage{ary…

人人讲视频如何下载

一、工具准备 1.VLC media player 2.谷歌浏览器 二、视频下载 1.打开人人讲网页&#xff0c;需要下载的视频 谷歌浏览器打开调试窗口 搜索m3u8链接 拷贝到VLCplayer打开网络串流方式打开测试是否能正常播放 2.下载视频 能正常播放后&#xff0c;切换播放为转换选择mp4格式…

【CPP】归并排序

目录 1.归并排序简介代码分析归并的非递归形式 1.归并排序 归并排序&#xff08;MERGE-SORT&#xff09; 是建立在归并操作上的一种有效的排序算法,该算法是采用分治法&#xff08;Divide andConquer&#xff09;的一个非常典型的应用。 将已有序的子序列合并&#xff0c;得到…

扩散模型 GLIDE:35 亿参数的情况下优于 120 亿参数的 DALL-E 模型

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学。 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 合集&#x…

com域名注册多少钱

COM域名注册价格视具体注册商而定&#xff0c;不同的注册商可能会有不同的收费标准。一般来说&#xff0c;COM域名注册价格在10美元到20美元之间&#xff0c;可根据不同的需求选择注册时间的长短&#xff0c;从1年到10年等不同时间段的注册费用也不同。以下是关于COM域名注册价…