线形回归与小批量梯度下降实例

news2025/1/14 16:39:27

1、准备数据集

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100)     #使用random.seed,设置一个固定的随机种子,
data_size = 150         # 数据集大小
x_range = 5             # x的范围
iteration_count = 100   # 迭代次数

# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。
                                                                                 
#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)

#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3

plt.scatter(x,y,marker='x',color='green')

#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()

#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)

#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,
                      batch_size = 20, #每一个小批量的数据规模是20
                      shuffle =True )  #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))

for index,(data,label) in enumerate(dataloader):
    print("index=%d num = %d"%(index,len(data)))

2、线性回归模型的训练思路

2.1 初始化参数

设置是模型参数:权重w和偏置b

初始化为随机值

设置 requires_grad=True,PyTorch 将记录这些张量的操作历史,用于后续的自动求导

2.2 循环训练(epoch

epoch 变量用于控制整个训练过程的迭代轮数

在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。

定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。

2.3 数据加载

内层循环通过 dataloader 遍历训练数据集的小批量数据。dataloader 是一个数据加载器,通常由 DataLoader 类创建,用于批量加载数据。

2.4 前向传播

假设 tensorX 是当前批次的数据

tensorY 是对应的真实标签

使用当前参数 w 和 b 计算预测值 h=w*tensorX +b。

2.5 计算损失

计算预测值 h 和真实值 tensorY 之间的均方误差(MSE),并保存到 loss

loss = torch.mean((h - tensorY) ** 2)

2.6 反向传播

调用 loss.backward() 进行反向传播,计算损失关于参数 w 和 b 的梯度

设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导

2.7 更新参数

使用梯度下降算法更新参数 w 和 b。学习率设置为0.01

w.data -= 0.01 * w.grad.data

b.data -= 0.01 * b.grad.data

沿着当前小批量计算的得到的梯度(导数)更新w和b

如果导数为0,则w、b保存不变

2.8 梯度清零

在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算

3、线性回归模型的实现

# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)

#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数
    # 在一个迭代轮次中,以小批量的方式,使用dataloader对数据
    # batch_index表示当前遍历的批次
    # data和label表示这个批次的训练数据和标记
    for batch_index,(data, label)in enumerate(dataloader):
        h = tensorX * w + b #计算当前直线的预测值,保存到h
        
        #计算预测值h和真实值y之间的均方误差,保存到loss中
        loss=torch.mean((h-tensorY)**2)
        #计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导
        loss.backward()
        
        #进行梯度下降,沿着梯度的反方向,更新w和b的值
        #沿着当前小批量计算的得到的梯度(导数)更新w和b
        #如果导数为0(Δw,Δb为0),则w、b保存不变
        w.data -=0.01 * w.grad.data
        b.data -=0.01 * b.grad.data
        
        print("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))
        
        #清空张量w和b中的梯度信息,为下一次迭代做准备
        w.grad.zero_()
        b.grad.zero_()
        
        #每次迭代,都打印当前迭代的轮数epoch
        #数据的批次batch idx和loss损失值
        
        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276562.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

P3884 [JLOI2009] 二叉树问题

题目描述: 如下图所示的一棵二叉树的深度、宽度及结点间距离分别为: - 深度:4 - 宽度:4 - 结点 8 和 6 之间的距离:8 - 结点 7 和 6 之间的距离:3 其中宽度表示二叉树上同一层最多的结点个数,节…

ssm旅游攻略网站设计+jsp

系统包含:源码论文 所用技术:SpringBootVueSSMMybatisMysql 需要源码或者定制看文章最下面或看我的主页 目 录 目 录 III 1 绪论 1 1.1 研究背景 1 1.2 目的和意义 1 1.3 论文结构安排 2 2 相关技术 3 2.1 SSM框架介绍 3 2.2 B/S结构介绍 3 …

Qt类的提升(Python)

from PyQt5.QtWidgets import QPushButtonclass apushbutton(QPushButton):def __init__(self, parentNone):super().__init__(parent)self.setText("Custom Button")self.setStyleSheet("background-color: yellow;")上述为一个“模板类”,命名…

kubernetes上安装kubesphere

准备工作 需要配置三台虚拟机 关闭防火墙 systemctl stop firewalldsystemctl disable firewalld 临时关闭selinux setenforce 0 永久关闭selinux vi /etc/selinux/config 安装docker rpm -qa|grep docker yum remove docker* -y rpm -qa|grep docker yum install -y yum-u…

Windows图形界面(GUI)-QT-C/C++ - QT控件创建管理初始化

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 目录 控件创建 包含对应控件类型头文件 实例化控件类对象 控件设置 设置父控件 设置窗口标题 设置控件大小 设置控件坐标 设置文本颜色和背景颜色 控件排版 垂直布局 QVBoxLayout …

分页工具代码重构

文章目录 1.common-mybatis-plus-starter1.目录2.PageInfo.java3.PageResult.java4.SunPageHelper.java 1.common-mybatis-plus-starter 1.目录 2.PageInfo.java package com.sunxiansheng.mybatis.plus.page;import lombok.EqualsAndHashCode; import lombok.ToString;impor…

Vue学习二——创建登录页面

前言 以一个登录页面为例子,这篇文章简单介绍了vue,element-plus的一些组件使用,vue-router页面跳转,pinia及持久化存储,axios发送请求的使用。后面的页面都大差不差,也都这么实现,只是内容&am…

初始Django框架

初识Django Python知识点:函数、面向对象。前端开发:HTML、CSS、JavaScript、jQuery、BootStrap。MySQL数据库。Python的Web框架: Flask,自身短小精悍 第三方组件。Django,内部已集成了很多组件 第三方组件。【主要…

深度学习每周学习总结R4(LSTM-实现糖尿病探索与预测)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客R6中的内容,为了便于自己整理总结起名为R4🍖 原作者:K同学啊 | 接辅导、项目定制 目录 0. 总结1. LSTM介绍LSTM的基本组成部分如何理解与应用LSTM 2. 数据预处理3. 数…

虚假星标:GitHub上的“刷星”乱象与应对之道

在开源软件的世界里,GitHub无疑是最重要的平台之一。它不仅是一个代码托管平台,也是一个社交网络,允许开发者通过“点赞”(即加星)来表达对某个项目的喜爱和支持,“星标”(Star)则成…

前端笔记----

在我的理解里边一切做页面的代码都是属于前端代码。 之前用过qt框架,也是用来写界面的,但是那是用来写客户端的,而html是用来写web浏览器的,相较之下htmlcssJavaScript写出来的界面是更加漂亮的。这里就记录我自个学习后的一些笔…

【面试题】技术场景 4、负责项目时遇到的棘手问题及解决方法

工作经验一年以上程序员必问问题 面试题概述 问题为在负责项目时遇到的棘手问题及解决方法,主要考察开发经验与技术水平,回答不佳会影响面试印象。提供四个回答方向,准备其中一个方向即可。 1、设计模式应用方向 以登录为例,未…

2025华数杯国际赛A题完整论文讲解(含每一问python代码+数据+可视化图)

大家好呀,从发布赛题一直到现在,总算完成了2025“华数杯”国际大学生数学建模竞赛A题Can He Swim Faster的完整的成品论文。 本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃圾半成品论文…

关闭window10或11自动更新和自带杀毒

关闭window10或11自动更新和自带杀毒 1.关闭系统更新**修改组策略关闭自动更新****修改服务管理器关闭自动更新** 2.关闭系统杀毒 为什么需要关闭更新和杀毒 案例: #装完驱动隔一段时间就掉 #一些设置隔一段时间就重置了 #防止更新系统后有时卡 1.关闭系统更新 我…

解析OVN架构及其在OpenStack中的集成

引言 随着云计算技术的发展,虚拟化网络成为云平台不可或缺的一部分。为了更好地管理和控制虚拟网络,Open Virtual Network (OVN) 应运而生。作为Open vSwitch (OVS) 的扩展,OVN 提供了对虚拟网络抽象的支持,使得大规模部署和管理…

【ArcGIS技巧】如何给CAD里的面注记导入GIS属性表中

前面分享了GIS怎么给田块加密高程点,但是没有分享每块田的高程对应的是哪块田,今天结合土地整理软件GLAND做一期田块的属性怎么放入GIS属性表当中。 1、GLAND数据 杭州阵列软件(GLand)是比较专业的土地整理软件,下载之…

Excel中SUM求和为0?难道是Excel有Bug!

大家好,我是小鱼。 在日常工作中有时会遇到这样的情况,对Excel表格数据进行求和时,结果竟然是0,很多小伙伴甚至都怀疑是不是Excel有Bug!其实,在WPS的Excel表格中数据求和,结果为0无法正确求和的…

Spring MVC简单数据绑定

【图书介绍】《SpringSpring MVCMyBatis从零开始学(视频教学版)(第3版)》_springspringmvcmybatis从零开始 代码、课件、教学视频与相关软件包下载-CSDN博客 《SpringSpring MVCMyBatis从零开始学(视频教学版)(第3版&…

蓝桥杯备考:数据结构之栈 和 stack

目录 栈的概念以及栈的实现 STL 的stack 栈和stack的算法题 栈的模板题 栈的算法题之有效的括号 验证栈序列 后缀表达式 括号匹配 栈的概念以及栈的实现 栈是一种只允许在一端进行插入和删除的线性表 空栈:没有任何元素 入栈:插入元素消息 出…

使用Dify创建个问卷调查的工作流

为啥要使用Dify创建工作流呢?一个基于流程的智能体的实现,特别是基于业务的实现,使用Dify去实现时,通常都是一个对话工作流,当设计到相对复杂一些的流程时,如果将所有逻辑都放在对话工作流中去实现&#xf…