深度学习 —— 个人学习笔记6(权重衰减)

news2024/11/16 2:20:22

声明

  本文章为个人学习使用,版面观感若有不适请谅解,文中知识仅代表个人观点,若出现错误,欢迎各位批评指正。

十三、权重衰减

  使用以下公式为例做演示:

y = 0.05 + ∑ i = 1 d 0.01 x i + ε w h e r e ε    ~    N ( 0 , 0.0 1 2 ) y = 0.05 + \sum_{i=1}^{d} 0.01x_i + \varepsilon \quad where \quad \varepsilon \; ~ \; N ( 0 , 0.01^2 ) y=0.05+i=1d0.01xi+εwhereεN(0,0.012)

  • 权重衰减的实现
import torch
from torch import nn
from d2l import torch as d2l
from IPython import display

def synthetic_data(w, b, num_examples):
    """生成 y = Xw + b + 噪声。"""
    X = torch.normal(0, 1, (num_examples, len(w))).cuda()                    # 均值为 0,方差为 1,有 num_examples 个样本,列数为 w 长度
    y = torch.matmul(X, w).cuda() + b                                        # y = Xw + b
    y += torch.normal(0, 0.01, y.shape).cuda()                               # 随机噪音
    return X, y.reshape((-1, 1))                                             # x,y作为列向量返回

class Animator:                                                                   # 定义一个在动画中绘制数据的实用程序类 Animator
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        # 通过以下两行代码实现了在PyCharm中显示动图
        d2l.plt.draw()
        d2l.plt.pause(interval=0.001)
        display.clear_output(wait=True)
        d2l.plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']


n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)).cuda() * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

##############    权重衰减的实现    #############
def init_params():
    """ 初始化参数 """
    w = torch.normal(0, 1, size=(num_inputs, 1)).cuda()
    b = torch.zeros(1).cuda()
    w.requires_grad_(True)
    b.requires_grad_(True)
    return [w, b]

def l2_penalty(w):
    """ 定义 L2 范数惩罚 """
    return (torch.sum(w.pow(2)) / 2).cuda()

def train(lambd):
    flag_button = "使用"
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 150, 0.005
    animator = Animator(xlabel='epochs', ylabel='loss', yscale='log',
                        xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了 L2 范数惩罚项,、
            # 广播机制使 l2_penalty(w) 成为一个长度为 batch_size 的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    # print('w的L2范数是:', torch.norm(w).item())
    if lambd == 0:flag_button = "禁用"
    d2l.plt.title(f"{flag_button}权重衰减 (lambda = {lambd})\nw 的 L2 范数是:{torch.norm(w).item()}")
    d2l.plt.show()


train(lambd=0)

train(lambd=15)


  • 权重衰减的简洁实现
import torch
from torch import nn
from d2l import torch as d2l
from IPython import display

def synthetic_data(w, b, num_examples):
    """生成 y = Xw + b + 噪声。"""
    X = torch.normal(0, 1, (num_examples, len(w))).cuda()                    # 均值为 0,方差为 1,有 num_examples 个样本,列数为 w 长度
    y = torch.matmul(X, w).cuda() + b                                        # y = Xw + b
    y += torch.normal(0, 0.01, y.shape).cuda()                               # 随机噪音
    return X, y.reshape((-1, 1))                                             # x,y作为列向量返回

class Animator:                                                                   # 定义一个在动画中绘制数据的实用程序类 Animator
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        # 通过以下两行代码实现了在PyCharm中显示动图
        d2l.plt.draw()
        d2l.plt.pause(interval=0.001)
        display.clear_output(wait=True)
        d2l.plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']


n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)).cuda() * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

##############    权重衰减的简洁实现    #############

def train_concise(wd):
    flag_button = "使用"
    net = nn.Sequential(nn.Linear(num_inputs, 1)).cuda()
    for param in net.parameters():
        param.data.normal_().cuda()
    loss = nn.MSELoss(reduction='none').cuda()
    num_epochs, lr = 150, 0.005
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = Animator(xlabel='epochs', ylabel='loss', yscale='log',
                        xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    # print('w的L2范数:', net[0].weight.norm().item())
    if wd == 0:flag_button = "禁用"
    d2l.plt.title(f"{flag_button}权重衰减 (lambda = {wd})\nw 的 L2 范数是:{net[0].weight.norm().item()}")
    d2l.plt.show()


train_concise(0)

train_concise(-2)  



  文中部分知识参考:B 站 —— 跟李沐学AI;百度百科

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942993.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

勇闯高龄“禁区”,四川眼科医院成功为95岁高龄老人实施泪道手术

一吹风就流泪、眼角总有擦不干净的分泌物……很多人以为这只是个滴眼药就能解决的小问题。其实不然,“不起眼”的疾病发展严重时可能还会需要手术治疗。 近日,四川眼科接诊了一位眼泪汪汪的耄耋老人张奶奶(化名),此次…

vue2 使用代码编辑器插件 vue-codemirror

vue 使用代码编辑器插件 vue-codemirror 之前用过一次,当时用的一知半解的,所以也没有成文,前几天又因为项目有需求,所以说有用了一次,当然,依旧是一知半解,但是还是稍微写一下子吧!…

学习测试10-4自动化 web自动化

网页资源 链接: https://pan.baidu.com/s/17XL2c2lkw_R6BD–VnOQqw?pwd43dr 提取码: 43dr 复制这段内容后打开百度网盘手机App,操作更方便哦 框架之间切换 driver.switch_to.frame("idframe1") # 父切子 参数用id和name# 子切子必须先转回父 driver.sw…

数据分析:微生物数据的荟萃分析框架

介绍 Meta-analysis of fecal metagenomes reveals global microbial signatures that are specific for colorectal cancer提供了一种荟萃分析的框架,它主要基于常用的Wilcoxon rank-sum test和Blocked Wilcoxon rank-sum test 方法计算显著性,再使用分…

STM32自己从零开始实操10:PCB全过程

一、PCB总体分布 分布主要参考有: 方便供电布线。方便布信号线。方便接口。人体工学。 以下只能让大家看到各个模块大致分布在板子的哪一块,只能说每个人画都有自己的理由,我的理由如下。 还有很多没有表达出来的东西,我也不知…

Python和MATLAB网络尺度结构和幂律度大型图生成式模型算法

🎯要点 🎯算法随机图模型数学概率 | 🎯图预期度序列数学定义 | 🎯生成具有任意指数的大型幂律网络,数学计算幂律指数和平均度 | 🎯随机图分析中巨型连接分量数学理论和推论 | 🎯生成式多层网络…

如何解决Windows系统目录权限问题

目录 前言1. 为什么会出现权限问题2. 修改文件权限的步骤2.1 确定目标文件2.2 右键属性设置2.3 更改所有者2.4 修改权限2.5 确认修改 3. 替换文件3.1 拷贝新的文件3.2 验证替换结果 结语 前言 在Windows系统中,时常需要往C盘系统目录下拷贝或者替换文件。然而&…

【Python系列】JSON 序列化性能对比分析

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【学术会议征稿】第五届计算机工程与智能通信国际研讨会(ISCEIC 2024)

第五届计算机工程与智能通信国际研讨会(ISCEIC 2024) 2024 5th International Symposium on Computer Engineering and Intelligent Communications (ISCEIC 2024) 第五届计算机工程与智能通信国际研讨会(ISCEIC 2024)将于2024年…

安全管理(EHS系统)是什么?化工企业如何进行安全管理?

化工企业一般会涉及到易燃易爆、有毒有害的原材料和产品,生产环境有高温高压、腐蚀性强等危险因素。一旦管理不善或操作失误,极易引发火灾、爆炸、中毒等严重事故,不仅有人身伤害,还会给企业带来巨大损失,甚至影响社会…

如何快速批量修改照片拍摄日期?一键批量搞定拍摄日期修改教程

在摄影爱好者、专业摄影师甚至普通用户中,照片不仅仅是视觉记录,它们还承载着时间和地点的印记。当需要调整大量照片的拍摄日期时,手动操作显然不是最高效的方法。幸运的是,现代文件管理工具如“简鹿文件批量重命名”软件提供了批…

数据隐私保护与区块链技术的结合:新兴趋势分析

在当今数字化时代,数据隐私保护成为了一个备受关注的重要话题。随着个人数据的不断生成和流通,如何有效保护用户的隐私成为了技术创新的一个重要方向。区块链技术作为一种去中心化、安全性高且可追溯的技术手段,正在逐渐成为解决数据隐私保护…

Android --- 广播

广播是什么? 一种相互通信,传递信息的机制,组件内、进程间(App之间) 如何使用广播? 组成部分 发送者-发送广播 与启动其他四大组件一样,广播发送也是使用intent发送。 设置action&#xff…

RoundCube搭建安装教程:服务器配置方法?

RoundCube搭建安装教程的疑问解析!怎么搭建邮件系统? RoundCube是一款开源的Web邮件客户端,具有现代化的用户界面和丰富的功能,可以通过浏览器访问邮件服务器。AokSend将详细介绍如何在服务器上配置和安装RoundCube,以…

JS语法学习

找到官方库,查看相应资料:(都可以切换为中文版本的) 可以在 JavaScript 的官方网站上查看最新的语法规范和文档。JavaScript 的官方网站是 developer.mozilla.orghttps://developer.mozilla.org/en-US/docs/Web/JavaScript。那里…

尚庭公寓开发笔记(一)

本篇文章讲的是p前五十节课 可以关注后续 传统的数据库设计流程 分为三个阶段:概念模型设计阶段 逻辑模型设计阶段 物理模型设计 阶段 为本项目设计数据库模型 地图的存储只需要保存经纬度就ok 本项目采用的是mysql数据库 所有表都使用的是innnodb存储引擎 我们使…

数据编织 VS 数据仓库 VS 数据湖

目录 1. 什么是数据编织?2. 数据编织的工作原理3. 代码示例4. 数据编织的优势5. 应用场景6. 数据编织 vs 数据仓库6.1 数据存储方式6.2 数据更新和实时性6.3 灵活性和可扩展性6.4 查询性能6.5 数据治理和一致性6.6 适用场景6.7 代码示例比较 7. 数据编织 vs 数据湖7.1 数据存储…

内网安全:IPC横向

IPC计划任务横向 IPC配合系统服务横向 前言: IPC是为了实现进程之间的通信而开放的管道。IPC可以通过验证用户名和密码来获取相应的权限。通过IPC可以与目标机器建立连接。 IPC计划任务横向 本次目标:通过机器192.168.11.40,横向控制机器192…

dependency-check-maven依赖漏洞扫描

引入插件依赖&#xff1a; <plugin><groupId>org.owasp</groupId><artifactId>dependency-check-maven</artifactId><version>7.0.4</version><configuration><autoUpdate>false</autoUpdate><dataDirectory&g…

SQL

SQL全称 Structured Query Language&#xff0c;结构化查询语言。操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库统一标准 。 SQL通用语法 SQL语句可以单行或多行书写&#xff0c;以分号结尾。SQL语句可以使用空格/缩进来增强语句的可读性。MySQL数据库的…