02|李沐动手学深度学习v2(笔记)

news2024/11/27 0:33:48

基础优化算法

导航

  • 基础优化算法
  • 梯度下降
    • 1.1 小批量随机梯度下降
    • 1.2 小结
  • 线性回归实现
    • 1. 处理数据
      • 1.3 生成大小为batch_size的小批量
    • 2. 处理模型
    • 3. 模型评估
    • 4. 训练过程


梯度下降

  • 针对我们的模型没有显示解。(生活中很少能有完全符合的线性模型,大多数模型都是没有显示解的)
  • 梯度:使函数的值增加最快的方向。
  • 负梯度:使函数的值下降最快的方向。
  • 学习率η:η沿着方向一次走多远。(η读作:yita)
  • (-η * 倒数)为函数下降最快的地方。那么w0+(-η * 倒数)为w1的位置。
    在这里插入图片描述

1.1 小批量随机梯度下降

  • 每次求梯度时,需要对整个损失函数求导。这个损失函数是对我们所有样本的平均损失。意味着:求一次梯度,需要对整个样本算一遍。这个开销很大,很贵的事情。
    在这里插入图片描述
    在这里插入图片描述

1.2 小结

在这里插入图片描述

线性回归实现

1. 处理数据

  • 如果没有d2l包,需要进入cmd,以管理员身份运行。输入:pip install -U d2l -i https://mirrors.aliyun.com/pypi/simple/下载即可。
  • 如果报:ModuleNotFoundError: No module named ‘torchvision’ ,直接在jupyter notebook中输入:pip install torchvision -i https://mirrors.aliyun.com/pypi/simple/
    在这里插入图片描述
    在这里插入图片描述

这段代码定义了一个名为synthetic_data的函数,用于生成合成数据。

该函数接受三个参数:

  • w:一个一维张量(向量),表示模型的权重。
  • b:一个标量,表示模型的偏置项。
  • num_examples:整数,表示要生成的数据样本数量。

函数的主要步骤如下:

  1. 使用torch.normal(0, 1, (num_examples, len(w)))生成形状为(num_examples, len(w))的服从标准正态分布的随机张量X,其中均值为0,标准差为1。
  2. 使用矩阵乘法运算符torch.matmul(X, w)X与权重w相乘,然后加上偏置项b,得到预测值y
  3. 使用torch.normal(0, 0.01, y.shape)生成形状与y相同的服从标准正态分布的随机噪声,并将其加到y上,以模拟真实数据的噪声。
  4. 最后,使用y.reshape((-1, 1))y转换为形状为(-1, 1)的二维张量,其中-1表示根据其他维度的大小自动计算该维度的大小。简单来说,reshape((-1,1))就是将数组转换成只有一列,行数不确定的二维数组。

函数返回生成的合成数据X和转换后的标签y

在这里插入图片描述
这段代码是使用 d2l 库绘制散点图的示例。

d2l.set_figsize() 用于设置图形的大小,可以指定宽度和高度。

d2l.plt.scatter(features[:,1].detach().numpy(), labels.detach().numpy(), 1) 这行代码绘制散点图,其中:

  • features[:,1] 表示从特征矩阵中取出第二列数据作为 x 坐标;
  • labels 表示标签数据;
  • 1 表示散点的半径为 1。

detach() 方法将张量从计算图中分离出来,返回一个在内存中独立的张量,这样可以避免在绘图时修改原始数据。numpy() 方法将张量转换为 NumPy 数组,以便在绘图函数中使用。

1.3 生成大小为batch_size的小批量

在这里插入图片描述
这段代码定义了一个名为 data_iter 的函数,用于生成批量训练数据的迭代器。

函数接受三个参数:batch_sizefeatureslabels。其中,batch_size 表示每个批次中的样本数量,features 是特征矩阵,labels 是标签向量。

首先,函数计算了样本总数 num_examples,然后创建了一个包含所有样本索引的列表 indices。接着,使用 random.shuffle() 方法将索引随机打乱。

接下来,函数使用一个循环来生成批次数据。在每次循环中,它从打乱后的索引列表中取出 batch_size 个索引,并将这些索引转换为一个张量 batch_indices。然后,函数使用 yield 语句返回当前批次对应的特征矩阵和标签向量。

由于使用了 yield 语句,该函数是一个生成器函数,可以在循环中逐个生成批次数据,而不需要一次性将所有数据加载到内存中。这样可以有效地减少内存占用,提高训练效率。

2. 处理模型

在这里插入图片描述
这段代码定义了两个 PyTorch 张量 wb,用于神经网络中的线性回归模型。

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True) 这行代码创建了一个形状为 (2,1) 的张量 w,其中的元素是从均值为 0、标准差为 0.01 的正态分布中随机采样得到的。requires_grad=True 表示需要计算该张量的梯度,以便在反向传播过程中更新参数。

b = torch.zeros(1, requires_grad=True) 这行代码创建了一个形状为 (1,) 的张量 b,其中的元素全部初始化为 0。requires_grad=True 表示需要计算该张量的梯度,以便在反向传播过程中更新参数。

在神经网络中,wb 通常分别表示线性回归模型的权重和偏置项。通过不断迭代优化算法(如随机梯度下降),可以更新 wb 的值,使得模型的预测结果越来越接近真实值。

3. 模型评估

在这里插入图片描述
这段代码定义了一个名为 sgd 的函数,用于执行小批量随机梯度下降算法。

函数接受三个参数:paramslrbatch_size。其中,params 是一个包含模型参数的列表或张量,lr 是学习率,batch_size 是每个批次中的样本数量。

函数使用 torch.no_grad() 上下文管理器来禁用梯度计算,以避免在反向传播过程中占用过多的内存。

接下来,函数使用一个循环遍历 params 中的每个参数。对于每个参数,它首先使用当前批次的梯度(即 param.grad)除以批次大小 batch_size,然后乘以学习率 lr,并从原始参数值中减去这个值,以更新参数。最后,使用 param.grad.zero_() 将该参数的梯度清零,以便在下一个迭代中使用新的梯度值。

总之,这段代码实现了一种简单的随机梯度下降算法,用于训练神经网络模型。

4. 训练过程

在这里插入图片描述
被挡住的代码为: print(f'epoch{epoch+1}, loss{float(train_l.mean()):f}')
这段代码是一个训练神经网络模型的完整流程,包括前向传播、计算损失、反向传播和参数更新。

首先,定义了一个 for 循环,用于迭代多个 epoch。在每个 epoch 中,使用 data_iter() 函数生成一批数据 X,y,其中 batch_size 是每个批次中的样本数量,features 是特征矩阵,labels 是标签向量。

接下来,对于每个数据点 X,使用 net(X,w,b) 进行前向传播,得到预测值 。然后,计算预测值与真实标签 y 之间的损失 loss(net(X,w,b),y)

接着,调用 l.sum().backward() 对损失进行反向传播,计算出每个参数的梯度。然后,使用随机梯度下降算法 sgd([w,b], lr, batch_size) 更新参数 wb

在每个 epoch 结束后,使用 with torch.no_grad(): 上下文管理器将梯度计算关闭,以避免在输出训练过程中占用过多的内存。然后,使用 train_l = loss(net(features, w, b), labels) 计算当前模型在测试集上的损失,并打印出当前 epoch 和平均损失。

总之,这段代码实现了一个标准的神经网络训练过程,通过不断迭代优化算法来提高模型的性能。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/969281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用户中心笔记-leovany

1. 安装 官方地址:https://pro.ant.design/zh-CN/docs/getting-started 1.1 Mac系统 1.1.1 安装yarn 安装yarn brew install yarn查看版本 brew -v 1.1.2 安装node // 安装node brew install node // 关联 brew unlink node && brew link node // 查看版…

信息系统安全运维模型 课堂记录

声明 本文是学习 信息系统安全运维管理指南. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 范围 本标准描述了信息系统安全运维管理体系,给出了安全运维策略、安全运维组织、安全运维规程和安全运维支撑系统等方面相关活动的目的、要求和…

【项目 计网9】4.25 IO多路复用简介 4.26select API介绍 4.27 select代码编写

文章目录 4.25 IO多路复用(I/O多路转接)简介4.26select API介绍4.27 select代码编写客户端程序select程序select的缺点 4.25 IO多路复用(I/O多路转接)简介 输入输出:以内存为主体 读写:以程序为主体 程序要…

2023-09-03 LeetCode每日一题(消灭怪物的最大数量)

2023-09-03每日一题 一、题目编号 1921. 消灭怪物的最大数量二、题目链接 点击跳转到题目位置 三、题目描述 你正在玩一款电子游戏,在游戏中你需要保护城市免受怪物侵袭。给你一个 下标从 0 开始 且长度为 n 的整数数组 dist ,其中 dist[i] 是第 i …

从一到无穷大 #12 Planet-Scale In-Memory Time Series Database, Is it really Monarch?

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 本作品 (李兆龙 博文, 由 李兆龙 创作),由 李兆龙 确认,转载请注明版权。 文章目录 引言约束优势数据模型写路径查询路径Field Hints Index可靠性 其他总结 引言 Monarc…

Thymeleaf常见属性

参考文档 thymeleaf 语法——th:text默认值、字符串连接、th:attr、th:href 传参、th:include传参、th:inline 内联、th:each循环、th:with、th:if_猎人在吃肉的博客-CSDN博客 代码演示 Controller public class TestController {AutowiredMenuService menuService;GetMapp…

基于多设计模式下的同步异步日志系统

基于多设计模式下的同步&异步日志系统 代码链接:https://github.com/Janonez/Log_System 1. 项目介绍 本项目主要实现一个日志系统, 其主要支持以下功能: 支持多级别日志消息支持同步日志和异步日志支持可靠写入日志到标准输出、文件…

uni-app之android原生插件开发

一 插件简介 1.1 当HBuilderX中提供的能力无法满足App功能需求,需要通过使用Andorid/iOS原生开发实现时,可使用App离线SDK开发原生插件来扩展原生能力。 1.2 插件类型有两种,Module模式和Component模式 Module模式:能力扩展&…

S32K324芯片学习笔记

文章目录 Core and architectureDMASystem and power managementMemory and memory interfacesClocksSecurity and integrity安全与完整性Safety ISO26262Analog、Timers功能框图内存mapflash Signal MultiplexingPort和MSCR寄存器的mapping Core and architecture 两个Arm Co…

数学建模:Yalmip求解线性与非线性优化问题

🔆 文章首发于我的个人博客:欢迎大佬们来逛逛 线性优化 使用 Yalmip 求解线性规划最优值: m i n { − x 1 − 2 x 2 3 x 3 } x 1 x 2 ⩾ 3 x 2 x 3 ⩾ 3 x 1 x 3 4 0 ≤ x 1 , x 2 , x 3 ≤ 2 \begin{gathered}min\{-x_1-2x_23x_3\} \…

networkX-01-基础

文章目录 创建一个图1. 节点方式1 :一次添加一个节点方式2:从list中添加节点方式3:添加节点时附加节点属性字典方式4:将一个图中的节点合并到另外一个图中 2. 边方式1:一次添加一条边方式2:列表&#xff08…

23062网络编程day2

1. TCP的服务器 客户端的代码 服务器 #include <myhead.h>#define ERR_MSG(msg) do{\fprintf(stderr,"__%d__:",__LINE__);\perror(msg);\ }while(0)#define PORT 8888#define IP "192.168.114.104"int main(int argc, const char *argv[]) {//创建…

大数据技术原理与应用学习笔记第1章

黄金组合访问地址&#xff1a;http://dblab.xmu.edu.cn/post/7553/ 1.《大数据技术原理与应用》教材 官网&#xff1a;http://dblab.xmu.edu.cn/post/bigdata/ 2.大数据软件安装和编程实践指南 官网林子雨编著《大数据技术原理与应用》教材配套大数据软件安装和编程实践指…

Windows 操作系统下 Python 及其模块的管理

Python 是一款解释型语言&#xff0c;理论上一个.py文件可以当成一个稍微复杂一些的字符串指令集本文不涉及jupyter,VS,VScode,Pycharm 等集成开发环境&#xff0c;这不是我们这篇文章所关心的东西 这篇文章面向的是Python 的初学者 最近没有写太多长文章&#xff0c;多写几篇&…

8、暴力递归

前缀树 一个字符串类型的数组arr1,另一个字符串类型的数组arr2。arr2中有哪些字符,是arr1中出现的?请打印。arr2中有哪些字符,是作为arr1中某个字符串前缀出现的?请打印。arr2中有哪些字符,是作为arr1中某个字符串前缀出现的?请打印 arr2中出现次数最大的前缀 public…

LabVIEW开发超导体电流特性的测量系统

LabVIEW开发超导体电流特性的测量系统 超导体的临界电流密度Jc不断增加&#xff0c;目前超导线已达到150MA/cm2因此&#xff0c;由于电流能力增强&#xff0c;超导体被认为应用于电力系统&#xff0c;例如传输电缆、超导磁体和超导磁储能。由于Jc是此类应用的重要值&#xff0…

STM32F4X RNG随机数发生器

STM32F4X RNG随机数发生器 随机数的作用STM32F4X 随机数发生器RNG控制寄存器RNG状态寄存器RNG数据寄存器RNG数据步骤RNG例程 随机数的作用 随机数顾名思义就是随机产生的数字&#xff0c;这种数字最大的特点就是其不确定性&#xff0c;你不知道它下一次产生的数字是什么。随机…

差分数组/前缀和

文章目录 1094. 拼车1109. 航班预定统计303. 区域和检索 - 数组不可变560. 和为K的子数组523. 连续的子数组的和 1094. 拼车 class Solution {public boolean carPooling(int[][] trips, int capacity) {int[] diff new int[1001]; // 记录每个站点改变的人数&#xff0c;比如…

c语言---指针

指针 前言 记录一个数据对象在内存中的存储位置&#xff0c;需要两个信息&#xff1a; 1、数据对象的首地址。 2、数据对象占用存储空间大小 基础数据类型所占内存空间大小&#xff08;字节&#xff09;&#xff0c;一个字节代表8个二进制位 char 1 short 2 int 4 lon…

Java中的网络编程------基于Socket的TCP编程和基于UDP的网络编程,netstat指令

Socket 在Java中&#xff0c;Socket是一种用于网络通信的编程接口&#xff0c;它允许不同计算机之间的程序进行数据交换和通信。Socket使得网络应用程序能够通过TCP或UDP协议在不同主机之间建立连接、发送数据和接收数据。以下是Socket的基本介绍&#xff1a; Socket类型&…