5. 线性回归的从零开始实现

news2024/11/22 23:57:07

1.生成数据集

# num_examples 表示样本数量,也就是房屋数量
# w是权重向量
def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    # X是一个从独立的正态分布中抽取的随机数的张量,正态分布的平均值为0、标准差为1
    # 第三项则表示矩阵X的形状:有num_examples行,len(w)列
    # 也就是说,对于矩阵X,每一行表示一个样本,每一列表示一个影响房屋价格的因素,有两个因素
    X = torch.normal(0, 1, (num_examples, len(w)))
    
    # torch.matual即时执行矩阵相乘
    y = torch.matmul(X, w) + b
    
    # 给y加入均值为0,标准差为0.01的噪音,噪音的形状和y的长度是一样的
    y += torch.normal(0, 0.01, y.shape)
    # reshape(-1,1),1表示把y向量转成1列,另一个维度,即行,是-1,那么就自动计算
    return X, y.reshape((-1, 1))

# 真实的w和b
true_w = torch.tensor([2, -3.4])
true_b = 4.2

# 调用函数,得到特征X和标签y
features, labels = synthetic_data(true_w, true_b, 1000)

在这里,关于reshape的两个的两个参数是用来调整矩阵的形状的,第一个表示行,第二个表示列,当为-1时,表示根据另一维度的固定值自动计算。如reshape(2,3)表示调整为2行3列,reshape(3,-1)表示调整为三行,列数根据行数自动计算

由传入函数的最后一个参数可知:传入1000表示1000个样本,那么矩阵X又或者是最后得到的features 就是一个1000*2的矩阵,2表示影响房价的两个因素。

接下来,打印输出第0个样本的数据以及对应标签:

print('features:',features[0],'\nlabel:',labels[0])

在这里插入图片描述

可以看出:第0个样本是一个长为2的向量,也就是和w的长度一样,而对应的标签是一个标量。其他样本同理。

接下来,可以把特征的第一列和label画出来:

d2l.set_figsize()
d2l.plt.scatter(features[:,1].detach().numpy(),
               labels.detach().numpy(),1);

把特征第1列(从第0列算起),也就是第2个特征拎出来和labels以图像形式呈现,scatter函数的最后一个参数1的意思是:绘制点直径的大小。

并且因为features是一个1000*2的矩阵,取出第一列,则代表取出1000个数,那么下图中有1000个点。

在这里插入图片描述

可以看到是有相关性的,并且是负相关,因为w中第二个标量是-3.4.同理,如果把X矩阵的第0列和labels画出来,结果应该是正相关的,因为w中第一个标量是2

2. 读取数据集

接下来,定义一个data_iter函数,该函数接受批量大小,特征矩阵和标签向量作为输入,生成大小为batch_size的小批量,目的是为了:从全样本集中抽取部分样本,以用来训练后面创建的模型

def data_iter(batch_size,features,labels):
	# features是1000*2的矩阵,len()这一函数表示其第一维度(行)的长度,也就是样本数量
	# 因为一行代表一个样本
    num_examples = len(features) # 样本数量
    indices = list(range(num_examples)) # 生成对于每个样本的index索引,从0~999
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices) # 这一函数是为了把indices这一个列表打乱,才能为后面实现随机读取
    for i in range(0,num_examples,batch_size):
    # 从样本0开始,直到最后一个样本,每一次跳batch_size的大小
        batch_indices = torch.tensor(indices[i:min(i+
                                                  batch_size,num_examples)])
    # 为什么要加上min这一函数呢?
    # 因为假设1000个样本,如果batch_size=15,那么最后一次取的时候是不足15的
    # 也就是最后一个batch不够,直接用剩余的即可。
        yield features[batch_indices],labels[batch_indices]
        
batch_size = 10

for x,y in data_iter(batch_size,features,labels):
    print(X,'\n',y)
    break

打印的结果如下:
在这里插入图片描述
解释:在调用了data_iter()函数的for循环中,有一个break,即:只打印了第一个batch,就退出循环了。也正好对应了batch_size=10,而X的第一个batch的样本,都是随机的,每一个样本都是一个二维的行向量。

3. 初始化参数模型

# w是一个长为2的列向量,将其随之初始化成均值为0,标准差为0.01的正态分布,并且需要计算梯度
w = torch.normal(0,0.01,size=(2,1),requires_grad=True)

# 对于偏差b而言,直接初始化为0,“1”表示标量,需要对偏差进行更新,所以requires_grad也等于True
b = torch.zeros(1,requires_grad=True)

4. 定义线性回归模型

将模型的输入和参数同模型的输出关联起来。要计算线性模型的输出, 我们只需计算输入特征X和模型权重w的矩阵-向量乘法后加上偏置b。 注意,上面的Xw是一个向量,而b是一个标量

回想一下广播机制,当我们用一个向量加一个标量时,标量会被加到向量的每个分量上。

def linreg(X,w,b):
    '''线性回归模型'''
    return torch.matmul(X,w)+b

5. 定义损失函数

因为需要计算损失函数的梯度,所以我们应该先定义损失函数。 在实现中,我们需要将真实值y的形状转换为和预测值y_hat的形状相同.

def squared_loss(y_hat,y):
    '''均方误差'''
    return (y_hat-y.reshape(y_hat.shape))**2 / 2

6. 定义优化算法

在每一步中,使用从数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度。

接下来,朝着减少损失的方向更新我们的参数(w和b)。 下面的函数实现小批量随机梯度下降更新。

该函数接受模型参数集合、学习速率和批量大小作为输入。每 一步更新的大小由学习速率lr决定。 因为我们计算的损失是一个批量样本的总和,所以我们用批量大小(batch_size) 来规范化步长,这样步长大小就不会取决于我们对批量大小的选择。

# params是list,包括w和b
def sgd(params,lr,batch_size):
    '''小批量随机梯度下降'''
    # 禁用梯度计算以加快计算速度
    # 当确保下文不用backward()函数计算梯度时可以用,用于禁用梯度计算功能,以加快计算速度
    with torch.no_grad(): # 不需要计算梯度,更新的时候不需要参与梯度计算
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_() # 梯度清零
            # pytorch会不断地累积变量的梯度,所以每更新一次参数,就要让其对应的梯度清零

7. 训练

现在我们已经准备好了模型训练所有需要的要素,可以实现主要的训练过程部分了。 理解这段代码至关重要,因为从事深度学习后, 你会一遍又一遍地看到几乎相同的训练过程。

在每次迭代中,我们读取一小批量训练样本,并通过我们的模型来获得一组预测。 计算完损失后,我们开始反向传播,存储每个参数的梯度。 最后,我们调用优化算法sgd来更新模型参数。

lr = 0.03 # 学习率为0.03
num_epochs = 3 # 对整个数据集进行三次扫描
net = linreg # 线性回归
loss = squared_loss

for epoch in range(num_epochs):
    for X,y in data_iter(batch_size,features,labels): # 拿出一个batch_size的X和y
        # net(X,w,b) 把X,w,b放入net中,用来做预测
        # loss(net(X,w,b),y) :把预测的y和真实的y来做损失
        l = loss(net(X,w,b),y) 
        # 得到的l是一个长为batch_size的列向量
        
        # l中的所有元素被加到一起,并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w,b],lr,batch_size) # 使用参数的梯度更新参数
        
    with torch.no_grad():
        train_1 = loss(net(features,w,b),labels)
        print(f'epoch{epoch+1},loss{float(train_1.mean()):f}')

在这里插入图片描述

因为我们使用的是自己合成的数据集,所以我们知道真正的参数是什么。 因此,我们可以通过比较真实参数和通过训练学到的参数来评估训练的成功程度。 事实上,真实参数和通过训练学到的参数确实非常接近。

print(f'w的估计误差:{true_w-w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b-b}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/53914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

双十二怎么入手,几款性能好物分享

过完了双十一,接下来就应该面临今年最后一个大优惠力度的双十二了,而且双十二的时间刚好靠近在过年,所以在这期间相信很多人购买的物品是更加偏向于家居用品方面,那么就不能够错过本篇文章了,本篇文章将为你们分享一些…

[附源码]计算机毕业设计springboot松林小区疫情防控信息管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

sonarqube安装

Sonarqube安装文档 1. 环境准备 参照官方文档Prerequisites and Overview | SonarQube Docs 安装符合sonarqube版本的JDK和数据库 目前服务器上JDK版本为11.0.2 sonarqube版本为9.1.0 postgresql版本为13.7 2. 安装JDK11.0.2 将openjdk-11.0.2_linux-x64_bin.tar.gz放到/usr/…

spring详解(一)

今天我们来学习一个新的框架spring!!! spring是什么呢? spring是2003年兴起的,是一款轻量级、非侵入式的IOC和AOP的一站式的java开发框架,为简化企业即开发而生。 轻量级:spring核心功能的jar包不大 非侵入式:我们的业务代码不需要继承或…

linux parted 方式挂盘,支持大于4T盘扩容

此 内容与之前的linux mbr转gpt格式有些重复,但为了便于查询,还是单抽出相关内容,进行操作: 1.查询要挂的有磁盘路径, 输入 parted -l 。 2 . 进入parted对/dev/vdb盘的交互方式:输入: parted /dev/vdb&am…

Spring Cloud Gateway 网关组件及搭建实例

Spring Cloud Gateway 是 Spring Cloud 团队基于 Spring 5.0、Spring Boot 2.0 和 Project Reactor 等技术开发的高性能 API 网关组件。Spring Cloud Gateway 旨在提供一种简单而有效的途径来发送 API,并为它们提供横切关注点,例如:安全性&am…

Linux 线程控制 —— 线程取消 pthread_cancel

线程退出pthread_exit只能终止当前线程,也就是哪个线程调用了pthread_exit,哪个线程就会退出;但是线程取消pthread_cancel ,不光可以终止自己,还可以终止其他线程。 》自己终止自己,没问题! 》…

Android ViewPager2 + TabLayout + BottomNavigationView

Android ViewPager2 TabLayout BottomNavigationView 实际案例 本篇主要介绍一下 ViewPager2 TabLayout BottomNavigationView 的结合操作 概述 相信大家都看过今日头条的的样式 如下: 顶部有这种tab 并且是可以滑动的, 这就是本篇所介绍的 ViewPager2 TabLayout 的组合…

【C++】C++实战项目机房预约管理系统

前言 这是C总结性练习,主要以一个综合案例对以前学过的知识进行复习巩固,为以后编程打下基础。 1. 机房预约系统需求 1.1 系统简介 学校有几个规格不同的机房, 由于使用时经常出现“撞车”现象,现开发一套机房预约系统&#x…

[附源码]JAVA毕业设计会议室租赁管理系统(系统+LW)

[附源码]JAVA毕业设计会议室租赁管理系统(系统LW) 目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技…

EMC原理 传导(共模 差模) 辐射(近场 远场) 详解

第一章、EMC概念介绍 EMC(electromagnetic compatibility)作为产品的一个特性,译为电磁兼容性;如果作为一门学科,则译为电磁兼容。它包括两个概念:EMI和EMS。EMI(electromagneticinterference) 电磁干扰&a…

从Github上整理下来的《Java面试神技》

该文档曾在Github上线6天,共收获55Kstar的Java面试神技(这赞数,质量多高就不用我多说了吧)非常全面,包涵Java基础、Java集合、JavaWeb、Java异常、OOP、IO与NIO、反射、注解、多线程、JVM、MySQL、MongoDB、Spring全家…

通俗易懂帮你理清操作系统(Operator System)

文章目录概念(是什么)设计OS的目的(为什么)如何理解 "管理"(怎么办)总结系统调用和库函数概念概念(是什么) 任何计算机系统都包含一个基本的程序集合,称为操作…

照亮无尽前沿之路:华为正成为科技灯塔的守护者

20世纪中叶,著名科学家、工程师,被誉为“信息时代之父”的范内瓦布什,在《科学:无尽的前沿》中讨论了科学战略与科学基础设施对科技发展的重要性。其中提出,人类科技发展已经从以个人、学校为单位,来到了以…

【能效管理】关于学校预付费水电系统云平台应用分析介绍

概述 安科瑞 李亚俊 壹捌柒贰壹零玖捌柒伍柒 当下智慧校园、平安校园的建设越来越普及,作为智慧校园建设的重要一环,学生宿舍的用电预付费和用电管理措施是必不可少的。学生宿舍预付费电控系统可以解决使用传统电表人工抄表费时费力,不方便统…

[附源码]JAVA毕业设计基于MVC框架的在线书店设计(系统+LW)

[附源码]JAVA毕业设计基于MVC框架的在线书店设计(系统LW) 目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 …

光源基础(2)——光的强度、波长、颜色合成与互补关系

光源基本参数 光的度量 辐射能和光能 以辐射形式发射、传播或接收的能量称为辐射能,其计量单位为焦耳(J)。光能是光通量在可见光范围内对时间的积分,其计量单位为流明秒(lms)。 辐射通量和光通量 辐射通量或辐射功率是以辐射形式发射、传播或接收的功率…

【servelt原理_4_Http协议】

Http协议 1.认识url url被称为统一资源定位符&#xff0c;用来表示从互联网上得到的资源位置和访问这些资源的方法。 他的表示方法一般为&#xff1a; <协议>://<主机>:<端口>/<路径>如下我们启动一个servlet程序&#xff0c;来看一下我们的url表示 …

Yolov5 基本环境(cpu)搭建记录

Yolov5 基本环境(cpu)搭建记录 软件包&#xff1a; 1.anaconda&#xff08;https://www.anaconda.com/&#xff09; 2.pycharm&#xff08;https://www.jetbrains.com/pycharm/&#xff09; 3.torchvision-0.11.0cpu-cp37-cp37m-win_amd64.whl&#xff08;https://download.py…

Node.js学习上(67th)

1、基础内容 1、命令行 1、CMD命令 1、dir&#xff1a;列出当前目录下的所有文件 2、cd 目录名&#xff1a;进入指定目录 3、md 目录名&#xff1a;新建文件夹 4、rd 目录名&#xff1a;删除文件夹 5、a.txt&#xff1a;直接打开当前目录下的文件 2、目录 1、.&#xff1a…