【动手学习深度学习--逐行代码解析合集】09权重衰减

news2024/10/6 16:18:08

【动手学习深度学习】逐行代码解析合集

09权重衰减


视频链接:动手学习深度学习–权重衰减
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

0、准备工作

import matplotlib # 注意这个也要import一次
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

'''
我们选择标签是关于输入的线性函数。 标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到d=200,并使用一个只包含20个样本的小训练集。
'''
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 生成权重w和偏置b
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# d2l.synthetic_data:生成y=Xw+b+噪声,返回x和y
train_data = d2l.synthetic_data(true_w, true_b, n_train)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
# d2l.load_array:按批量大小取出训练集or测试集
train_iter = d2l.load_array(train_data, batch_size)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

1、随机初始化模型参数

"====================1、随机初始化模型参数===================="
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

2、定义L2范数惩罚

"====================2、定义L2范数惩罚===================="
# 实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。不包含lambd
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2
def l1_penalty(w):
    return torch.sum(torch.abs(w))

3. 定义训练代码实现

"====================3. 定义训练代码实现===================="
def train(lambd):
    w, b = init_params()
    # 匿名函数:输入X,输出 线性函数 和 损失函数
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    # 训练轮数、学习率
    num_epochs, lr = 100, 0.003
    # 绘图函数
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            # 反向传播梯度
            l.sum().backward()
            # 参数更新
            d2l.sgd([w, b], lr, batch_size)
        # 每5轮产生的数据加入图像中
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

4. 忽略正则化直接训练

"====================4. 忽略正则化直接训练===================="
# 我们现在用lambd = 0禁用权重衰减后运行这个代码。
# 注意,这里训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。
train(lambd=0)

【输出】w的L2范数是: 13.232867240905762

在这里插入图片描述

5、使用权重衰减

"====================5. 使用权重衰减===================="
# 注意,在这里训练误差增大,但测试误差减小。 这正是我们期望从正则化中得到的效果。
train(lambd=3)

【输出】 w的L2范数是: 0.360396146774292

在这里插入图片描述
补充:使用L1范数效果更好一点

当lambd=3
【输出】 w的L2范数是: 0.07004229724407196
在这里插入图片描述

权重衰减简洁实现

简洁表达的过程大体上没什么不一样,只不过把lambd改成了wd
pytorch代码中将weight_decay表示正则,把wd传入该值

"权重衰减简洁实现"
def train_concise(wd):
    # net直接使用nn库中自带的linear,传入输入维度(num_input)和输出维度1
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    # 取出net中的参数,按照正态分布对参数进行赋值
    for param in net.parameters():
        param.data.normal_()
    # 损失函数
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减。偏置名称通常以“bias”结尾
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # d2l中的训练函数已经帮我们清空过梯度了,但是这里的优化器是用的nn中自带的sgd,不会帮我们自动清空,所以我们要先手动清零
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

train_concise(0)
d2l.plt.show()

【输出】w的L2范数: 13.809814453125

在这里插入图片描述

train_concise(3)
d2l.plt.show()

【输出】w的L2范数: 0.35754090547561646

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/727100.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Wordpress的mysql迁库遇到问题

在我们迁移库的时候经常会出现如下问题: 5.7日期默认0000-00-00 00:00:00 设置错误。 MySQL默认设置中不支持日期datetime格式下的0000-00-00 00:00:00。 解决方法如下: select sql_mode 来查看对应内容 ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO…

vue打包后,生成的dist文件出现浏览器缓存问题——技能提升

最近同事问我,打包后的项目放在服务器上后,在终端打开时,必须清除浏览器缓存也就是ctrlF5才可以。 我打包后查看dist/index.html文件 发现没有css和js文件都有不同版本号的标识,按道理来说,是不会出现这个缓存的问题…

在Chrome谷歌浏览器中执行JavaScript的方法

在Chrome谷歌浏览器中如何执行JavaScript?在Chrome 浏览器中可以通过按下 F12 按钮或者右击页面,选择"检查"来开启开发者工具。 也可以在右上角菜单栏选择 "更多工具"》"开发者工具" 来开启: 1、Console 窗口调…

基于机器学习的情感分析

1基于机器学习 是指选取情感词作为特征词,将文本矩阵化,利用logistic Regression, 朴素贝叶斯(Naive Bayes),支持向量机(SVM)等方法进行分类。最终分类效果取决于训练文本的选择以及正确的情感标注。 在训练过程&#…

骨传导耳机音质怎么样,盘点在音质方面表现不错的五款骨传导耳机

骨传导耳机凭借不入耳就能轻松听音乐的特点,被越来越多人所认识,相比传统的入耳式耳机,骨传导耳机拥有更多的可玩性,比如说跑步、游泳、健身都可以佩戴骨传导耳机,即使长时间佩戴也不会出现不适感,也不会出…

LinearAlgebraMIT_3_InverseMatrix

x.1 矩阵乘法 矩阵乘法的常用运算规则有五种,如下是一种,是最简单的矩阵乘法,用一行乘以一列,假设A是mxn的矩阵,B是nxp的矩阵,则最终得到mxp的矩阵。 在矩阵A和向量a乘法中,我们已经习惯性地将…

Sumifs函数(excel)

SUMIFS 函数是一个数学与三角函数,用于计算其满足多个条件的全部参数的总量。excel如何使用Sumifs函数? 工具/原料 联想ThinkPad X1 windows7 WPS office2021 方法/步骤 首先运行office软件,打开一份表格,今天我们要计算以“…

SDN-OpenDaylight与Mininet的原理、安装、使用

一、前言 本文将介绍OpenDaylight与Mininet的原理并介绍他们的安装及简单的使用,本实验的环境为Liunx Ubuntu 16.04,已成功安装OVS,但没有安装Mininet。 二、原理 (一)OpenDaylight OpenDaylight是一个软件定义网络&…

【抖音小游戏】 Unity制作抖音小游戏方案 最新完整详细教程来袭【持续更新】

前言 【抖音小游戏】 Unity制作抖音小游戏方案 最新完整详细教程来袭【持续更新】一、相关准备工作1.1 用到的相关网址1.2 注册字节开发者后台账号 二、相关集成工作2.1 下载需要的集成资源2.2 安装StarkSDK和starksdk-unity-tools工具包2.3 搭建测试场景 三、构建发布3.1 发布…

2.5 DNS 应用 -- 1. DNS 概述

2.5 DNS 应用 -- 1. DNS 概述 DNS:Domain Name SystemDNS分布式层次式数据库DNS根域名服务器TLD和权威域名解析服务器本地域名解析服务器 DNS 查询迭代查询递归查询 DNS记录缓存和更新 DNS:Domain Name System Internet上主机/路由器的识别问题 IP地址域…

基于matlab处理 RGB-D图像数据以构建室内环境地图并估计相机的轨迹(附源码)

一、前言 视觉同步定位和映射 (vSLAM) 是指计算摄像机相对于周围环境的位置和方向,同时映射环境的过程。 您可以使用单眼摄像头执行 vSLAM。但是,深度无法准确计算,估计的轨迹未知,并且随着时间的推移而漂…

红帽恪守对开源的承诺:对 git.centos.org 变更的回应

导读红帽上周宣布了限制源代码访问性的政策,称其企业发行版 RHEL (Red Hat Enterprise Linux) 相关源码仅通过 CentOS Stream 公开,付费客户和合作伙伴可通过 Red Hat Customer Portal 访问到源代码。 此举引发了巨大争议,红帽甚至被指责 “背…

大数据开发环境-Hbase

1.启动之前需要确保hadoop启动 # 查看 Hadoop 是否已经正常启动 : start-all.sh jps 2.启动Hbase

运输层:TCP可靠传输

1.运输层:TCP可靠传输 笔记来源: 湖科大教书匠:TCP可靠传输 声明:该学习笔记来自湖科大教书匠,笔记仅做学习参考 TCP实现可靠传输的方式:以字节为单位的滑动窗口 发送方将31 ~ 41号报文段发送 假设32 ~ 3…

React04-Hooks 详解

一、Hooks 1. Hooks 简介 Hooks,可以翻译成钩子。 在类组件中钩子函数就是生命周期函数,Hooks 主要用在函数组件。 在 react 中定义组件有2种方式:class 定义的类组件和 function 定义的函数组件。 在类组件中,钩子函数可以给…

学生适合用什么台灯护眼?暑假适合孩子学习的台灯分享

又要临近暑假了,孩子们又要开始整天围着手机、电视、平板等等,想想就感觉到头疼。也有些家长趁着暑假期间给孩子报一下兴趣班,培养一下孩子的技能和情操。不过也要注意孩子的视力健康,不少孩子就是因为在暑假期间没有注意用眼习惯…

Camtasia 2023.1.0免费版电脑视频录制和剪辑软件

Camtasia Studio是一套专业的屏幕录像软件,同时包含Camtasia 录像器、Camtasia Studio(编辑器)、Camtasia 菜单制作器、Camtasia 剧场、Camtasia 播放器和Screencast的内置功能。Camtasia 是一款专门捕捉屏幕影音的工具软件。它能在任何模式下…

企业金蝶云星空服务器数据库中了locked勒索病毒如何应对

近日,很多企业的金蝶云星空财务账套被locked勒索病毒攻击,财务系统内的许多重要数据被加密,无法正常打开,计算机内的所有文件的扩展名全部都变成了.locked后缀勒索病毒,导致服务器数据库被锁定。这种情况的出现与企业的…

云原生之深入解析K8S Istio Gateway服务的架构分析与实战操作

一、概述 Istio 提供一种简单的方式来为已部署的服务建立网络,该网络具有负载均衡、服务间认证、监控、网关等功能,而不需要对服务的代码做任何改动。 istio 适用于容器或虚拟机环境(特别是 k8s),兼容异构架构&#x…

6.1 计算机网络应用模式

6.1 计算机网络应用模式 计算机网络应用模式与计算机网络的发展密切相关,大体可以分为三个阶段 以大型机为中心的应用模式(mainframe-centric) 该应用模式也称为分时共享(time-sharing)模式,也就是面向终端…