【李沐】3.2线性回归从0开始实现

news2025/1/12 8:03:51
%matplotlib inline
import random
import torch
from d2l import torch as d2l

1、生成数据集:
看最后的效果,用正态分布弄了一些噪音
在这里插入图片描述
上面这个具体实现可以看书,又想了想还是上代码把:
在这里插入图片描述
按照上面生成噪声,其中最后那个代表服从正态分布的噪声

def synthetic_data(w, b, num_examples):  # 定义函数 synthetic_data,接受权重 w、偏差 b 和样本数量 num_examples 作为参数
    """生成 y = Xw + b + 噪声 的合成数据集"""
    
    # 生成一个形状为 (num_examples, len(w)) 的特征矩阵 X,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样得到
    X = torch.normal(0, 1, (num_examples, len(w)))
    
    # 计算目标值 y,通过将特征矩阵 X 与权重 w 相乘,然后加上偏差 b,模拟线性回归的预测过程
    y = torch.matmul(X, w) + b
    
    # 给目标值 y 添加一个小的随机噪声,以模拟真实数据中的噪声。噪声从均值为 0、标准差为 0.01 的正态分布中随机采样得到
    y += torch.normal(0, 0.01, y.shape)
    
    # 返回特征矩阵 X 和目标值 y(将目标值 y 重塑为列向量的形式)
    return X, y.reshape((-1, 1)
# 定义真实的权重 true_w 为 [2, -3.4]
true_w = torch.tensor([2, -3.4])

# 定义真实的偏差 true_b 为 4.2
true_b = 4.2

# 调用 synthetic_data 函数生成合成数据集,传入真实的权重 true_w、偏差 true_b 和样本数量 1000
# 这将返回特征矩阵 features 和目标值 labels
features, labels = synthetic_data(true_w, true_b, 1000)


2、读取数据集
注意一般情况下要打乱。
下面函数的作用是该函数接收批量⼤⼩、特征矩阵和标签向量作为输⼊,⽣成⼤⼩为batch_size的⼩批量。每个⼩批量包含⼀组特征和标签。

def data_iter(batch_size, features, labels):
    num_examples = len(features)  # 获取样本数量
    indices = list(range(num_examples))  # 创建一个样本索引列表,表示样本的顺序
    
    # 将样本索引列表随机打乱,以便随机读取样本,没有特定的顺序
    random.shuffle(indices)
    
    # 通过循环每次取出一个批次大小的样本
    for i in range(0, num_examples, batch_size):
        # 计算当前批次的样本索引范围,确保不超出总样本数量
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        
        # 通过索引获取对应的特征和标签,然后通过 yield 返回这个批次的数据
        # yield 使得函数可以作为迭代器使用,在每次迭代时产生一个新的批次数据
        yield features[batch_indices], labels[batch_indices]

3、初始化模型参数
第一步:前面两行代码,,我
们通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重,并将偏置初始化为0。
计算梯度使用2.5节引入的自动微分

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

4、定义模型
这里注意b是一个标量和向量相加,咋办?
前面说过向量的广播机制,就相当于是加到每一个上面

def linreg(X, w, b): #@save
"""线性回归模型"""
return torch.matmul(X, w) + b

5、定义损失函数
y.reshape(y_hat.shape))啥意思?
y_hat是真实值,这里的意思是弄成和y_hat相同的大小

def squared_loss(y_hat, y): #@save
"""均⽅损失"""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

6、优化算法
问:这里的参数是啥参数?params
更新完的参数不用返回吗?
为什么需要梯度清零?

def sgd(params, lr, batch_size):  # 定义函数 sgd,接受参数 params、学习率 lr 和批次大小 batch_size
    """小批量随机梯度下降"""
    
    with torch.no_grad():  # 使用 torch.no_grad() 来关闭梯度跟踪,以减少内存消耗
    
        for param in params:  # 遍历模型参数列表
            param -= lr * param.grad / batch_size  # 更新参数:参数 = 参数 - 学习率 * 参数梯度 / 批次大小
            param.grad.zero_()  # 清零参数的梯度,以便下一轮梯度计算

7、训练
问:反向传播是为了干啥?
是为了计算梯度,那梯度是啥呢
梯度是参数更快收敛的方向(就是向量)
优化方法是干啥的?
优化方法就是根据上面传过来的梯度,计算参数更新
所以,这几章看完后需要梳理深度学习的整个过程,以及每块有哪些方法,这些方法的特点和用那种方法更好
问(1)每个epoch训练多少数据?
整个训练集
(2)损失函数是啥?
损失函数是用来计算真实值域预测值之间的距离,当然是距离越小越好,可以拿均方误差想一下
(3)l.sum().backward()是啥意思?
看注释,补充:.backward() 方法用于执行自动求导,计算总的损失值对于模型参数的梯度。这将会构建计算图并沿着图的反向传播路径计算梯度。
(4)但是上面所说的梯度保存在哪里呢?
w.grad 和 b.grad 中
(5)但是sgd中也没有用到w.grad 啊?
用到了,param 可以是 w 或者 b,而 param.grad 则是相应参数的梯度。
(6)新问题:train_l = loss(net(features, w, b), labels)不是在前面已经计算过损失函数了吗?为啥在这里还需要计算?
前面计算损失函数是间断性的,目的是更新模型参数。
后面仍然计算的目的是根据更新完的参数对模型在整个训练集上与真实标签的差距做一个评估。

lr = 0.03  # 设置学习率为 0.03,控制每次参数更新的步幅

num_epochs = 3  # 设置训练的轮次(迭代次数)为 3,即遍历整个数据集的次数

net = linreg  # 定义模型 net,通常表示线性回归模型

loss = squared_loss  # 定义损失函数 loss,通常为均方损失函数,用于衡量预测值与真实值之间的差距
for epoch in range(num_epochs):  # 迭代 num_epochs 轮,进行训练

    for X, y in data_iter(batch_size, features, labels):  # 遍历数据集的每个批次
        
        l = loss(net(X, w, b), y)  # 计算当前批次的损失值 l,表示预测值与真实值之间的差距
        
        # 因为 l 的形状是 (batch_size, 1),而不是一个标量。将 l 中的所有元素加起来,
        # 并计算关于 [w, b] 的梯度
        l.sum().backward()
        
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数,执行随机梯度下降算法
        
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)  # 在整个训练集上计算损失值
        
        # 打印当前迭代轮次和训练损失值的均值
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

8、练习中的问题

  1. 如果我们将权重初始化为零,会发⽣什么。算法仍然有效吗?
    无效,为啥?因为,不同的X输入是相同的输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/901280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

韩顺平Linux 四十四--

四十四、rwx权限 权限的基本介绍 输入指令 ls -l 显示的内容如下 -rwxrw-r-- 1 root 1213 Feb 2 09:39 abc0-9位说明 第0位确定文件类型(d , - , l , c , b) l 是链接,相当于 windows 的快捷方式- 代表是文件是普通文件d 是目录,相…

【java毕业设计】基于ssm+mysql+jsp的社区生活超市管理系统设计与实现(程序源码)-社区生活超市管理系统

基于ssmmysqljsp的社区生活超市管理系统设计与实现(程序源码毕业论文) 大家好,今天给大家介绍基于ssmmysqljsp的社区生活超市管理系统设计与实现,本论文只截取部分文章重点,文章末尾附有本毕业设计完整源码及论文的获取…

webshell绕过

文章目录 webshell前置知识进阶绕过 webshell 前置知识 <?phpecho "A"^""; ?>运行结果 可以看到出来的结果是字符“&#xff01;”。 为什么会得到这个结果&#xff1f;是因为代码的“A”字符与“”字符产生了异或。 php中&#xff0c;两个变…

系统架构设计专业技能 · 系统工程与系统性能

系列文章目录 系统架构设计专业技能 网络技术&#xff08;三&#xff09; 系统架构设计专业技能 系统安全分析与设计&#xff08;四&#xff09;【系统架构设计师】 系统架构设计高级技能 软件架构设计&#xff08;一&#xff09;【系统架构设计师】 系统架构设计高级技能 …

7-10 最佳情侣身高差

分数 10 全屏浏览题目 切换布局 作者 陈越 单位 浙江大学 专家通过多组情侣研究数据发现&#xff0c;最佳的情侣身高差遵循着一个公式&#xff1a;&#xff08;女方的身高&#xff09;1.09 &#xff08;男方的身高&#xff09;。如果符合&#xff0c;你俩的身高差不管是牵手…

Shell脚本基础( 四: sed编辑器)

目录 1 简介 1.1 sed编辑器的工作流程 2 sed 2.1 基本用法 2.2 sed基本格式 2.2.1 sed支持正则表达式 2.2.2 匹配正则表达式 2.2.3 奇数偶数表示 2.2.4 -d选项删除 2.2.5 -i修改文件内容 2.2.6 -a 追加 2.3 搜索替代 2.4 变量 1 简介 sed是一种流编辑器&#xff0c;…

我能“C”——数据的存储

目录 1. 数据类型介绍 1.1 类型的基本归类&#xff1a; 2. 整形在内存中的存储 2.1 原码、反码、补码 2.2 大小端介绍 2.3 练习 3. 浮点型在内存中的存储 3.1 一个例子 3.2 浮点数存储规则 1. 数据类型介绍 char // 字符数据类型 short // 短整…

linux字符设备

目录 设计字符设备 文件系统调用系统IO的内核处理过程 硬件层原理 驱动层原理 文件系统层原理 设备号的组成与哈希表 Hash Table&#xff08;哈希表、散列表&#xff0c;数组和链表的混合使用&#xff09; 设备号管理 关键的数据结构&#xff1a;char_device_struct&a…

Python应用工具-Jupyter Notebook

工具简介 Jupyter Notebook是 基于 网页的用于交互计算的 应用程序&#xff0c;以网页的形式打开&#xff0c;可以在网页页面中直接编写代码和运行代码&#xff0c;代码的运行结果也会直接在代码块下 显示&#xff0c;文档是保存为后缀名为 . ipynb 的 JSON 格式文件。 操作指令…

学习笔记:Opencv实现限制对比度得自适应直方图均衡CLAHE

2023.8.19 为了完成深度学习的进阶&#xff0c;得学习学习传统算法拓展知识面&#xff0c;记录自己的学习心得 CLAHE百科&#xff1a; 一种限制对比度自适应直方图均衡化方法&#xff0c;采用了限制直方图分布的方法和加速的插值方法 clahe&#xff08;限制对比度自适应直方图…

AI搜索引擎助力科学家创新

开发者希望通过帮助科学家从大量文献中发现联系从而解放科学家&#xff0c;让他们专注于发现和创新。 图片来源&#xff1a;The Project Twins 对于专注于历史的研究者Mushtaq Bilal来说&#xff0c;他在未来科技中投入了大量时间。 Bilal在丹麦南部大学&#xff08; Universit…

畅享个性海报创作——探索免费开源的在线自动生成海报项目魅力

我们的生活越来越离不开各种创意和宣传&#xff0c;而其中一个常见的需求就是制作精美的海报。然而&#xff0c;对许多人来说&#xff0c;制作海报可能并不是一件轻松的事情&#xff0c;往往需要专业的设计技能或者花费不少时间去请人帮忙。今天了我给大家介绍一款开源的可私有…

SQL助你面大厂(连续N天登录)

在腾讯、网易或者一些游戏类大厂中&#xff0c;他们经常关注的就是用户上线人数以及天数&#xff0c;那么给我们一个数据库&#xff0c;我们怎么样才能快速的查询那个用户的连续N天登录&#xff1f; 那我们用案例来说明&#xff0c;再多的语言在现实面前总是那么苍白无力&…

mongodb 数据库管理(数据库、集合、文档)

目录 一、数据库操作 1、创建数据库 2、删除数据库 二、集合操作 1、创建集合 2、删除集合 三、文档操作 1、创建文档 2、 插入文档 3、查看文档 4、更新文档 1&#xff09;update() 方法 2&#xff09;replace() 方法 一、数据库操作 1、创建数据库 创建数据库…

HCIP——VLAN实验2

一.实验要求 1.PC1/3的接口均为access模式&#xff0c;且属于van2&#xff0c;在同一网段 2.PC2/4/5/6的IP地址在同一网段&#xff0c;与PC1/3不在同一网段 3.PC2可以访问4/5/6&#xff0c;PC4不能访问5/6&#xff0c;PC5不能访问PC6 4.所有PC通过DHCP获取ip地址&#xff0c;PC…

第 7 章 排序算法(1)

7.1排序算法的介绍 排序也称排序算法(Sort Algorithm)&#xff0c;排序是将一组数据&#xff0c;依指定的顺序进行排列的过程。 7.2排序的分类&#xff1a; 内部排序: 指将需要处理的所有数据都加载到**内部存储器(内存)**中进行排序。外部排序法&#xff1a; 数据量过大&am…

深入探索:Kali Linux 网络安全之旅

目录 前言 访问官方网站 导航到下载页面 启动后界面操作 前言 "Kali" 可能指的是 Kali Linux&#xff0c;它是一种基于 Debian 的 Linux 发行版&#xff0c;专门用于渗透测试、网络安全评估、数字取证和相关的安全任务。Kali Linux 旨在提供一系列用于测试网络和…

【数据结构】吃透单链表!!!(详细解析~)

目录 前言&#xff1a;一.顺序表的缺陷 && 介绍链表1.顺序表的缺陷2.介绍链表&#xff08;1&#xff09;链表的概念&#xff08;2&#xff09;链表的结构&#xff08;3&#xff09;链表的功能 二.单链表的实现1.创建节点的结构2.头文件函数的声明3.函数的实现&#xff…

一、docker及mysql基本语法

文章目录 一、docker相关命令二、mysql相关命令 一、docker相关命令 &#xff08;1&#xff09;拉取镜像&#xff1a;docker pull <镜像ID/image> &#xff08;2&#xff09;查看当前docker中的镜像&#xff1a;docker images &#xff08;3&#xff09;删除镜像&#x…

golang云原生项目之:etcd服务注册与发现

服务注册与发现&#xff1a;ETCD 1直接调包 kitex-contrib&#xff1a; 上面有实现的案例&#xff0c;直接cv。下面是具体的理解 2 相关概念 EtcdResolver: etcd resolver是一种DNS解析器&#xff0c;用于将域名转换为etcd集群中的具体地址&#xff0c;以便应用程序可以与et…