1-简单回归问题

news2025/1/12 23:17:37

一.梯度下降(gradient descent)
1.预测函数
这里有一组样本点,横纵坐标分别代表一组有因果关系的变量

在这里插入图片描述

我们的任务是设计一个算法,让机器能够拟合这些数据,帮助我们算出参数w

在这里插入图片描述

我们可以先随机选一条过原点的直线,然后计算所有点到该直线的偏离程度(即误差)

在这里插入图片描述

再根据误差大小调整直线的斜率w,这里的y=wx就是预测函数

2.损失函数(loss function)/代价函数(cost function)
对于一个点(x1,y1),误差e1=y1-wx1

在这里插入图片描述

这里使用最小平方误差(Ordinary Least Squares, OLS),即将误差平方

在这里插入图片描述

将所有点的误差平方再展开,其中x1,y1,n均为已知数

在这里插入图片描述

将其相加求平均,再合并同类项
其中a>0

在这里插入图片描述

即可表示为

在这里插入图片描述

即代价函数cost/损失函数loss

在这里插入图片描述

在这里插入图片描述

这样就完成了预测函数到代价函数的映射过程,随着左图w的增大,右图的点向右移动

在这里插入图片描述

3.梯度计算
我们的目的是实现损失最小,即抛物线取得最小值时参数w的值
假设起始点在曲线上任意一处,寻找最低点的过程就是梯度下降
选择的下降方向是切线方向/梯度的反方向/陡峭程度最大的方向

梯度(gradient)是代价函数的导数
在这里插入图片描述

4.学习率(learning rate)
每一次更新参数利用多少误差, 就需要通过一个参数来控制, 这个参数就是学习率,也称为步长。
选择最优学习率是很重要的,因为它决定了我们是否可以迅速收敛到全局最小值。
小的学习率需要多次更新才能达到最低点,并且需要花费很多时间,且很容易仅收敛到局部极小值
学习率过大会导致剧烈的更新,可能总是在全局最小值附近,但是从未收敛到全局最小值
最佳学习率迅速达到最低点

在这里插入图片描述

每次新的w=旧的w-斜率*学习率
其中斜率=f导=梯度

在这里插入图片描述
循环迭代过程:定义代价函数→选择起始点→计算梯度→按学习率前进→计算梯度→按学习率前进→…到达最低点

二.线性回归(Linear Regression)
线性回归是一种统计学和机器学习中常用的预测方法,用于建立一个自变量(或称为特征)与因变量之间的线性关系模型。它假设自变量和因变量之间存在一个线性关系,并尝试通过拟合一条最佳拟合直线(或超平面)来进行预测。线性回归的目标是通过最小化预测值与实际观测值之间的差异(误差或残差)来找到最佳拟合直线或超平面。

在简单线性回归中,只有一个自变量和一个因变量之间的关系。这可以表示为一条直线的方程:y=wx+b
通过该预测函数我们可以得到误差 e=(wx+b-y)²
误差求和得到损失函数loss
在这里插入图片描述
目的是找到最小loss(error)时w和b的值

在这里插入图片描述

对于一个二元一次方程组,我们通常使用两式相减的方式求出参数b和w的值。这种可以精确求解的我们叫做闭合解(Closed-form Solution,也称封闭解)
在这里插入图片描述
但实际数据是有误差的,我们只能求得近似解
即实际的y=wx+b+ε,这里的ε叫做高斯噪声,由于高斯噪声的存在使得数据有误差
通过x和y的多组数据可以使结果更接近Closed-form Solution
在这里插入图片描述

下面使用代码实现二元一次方程组的求解

数据下载:提取码:zn73
数据点集合的每一行表示一个数据点,第一列是自变量 xi,第二列是因变量 yi

在这里插入图片描述

1.计算线性回归模型的误差函数
代码通过迭代遍历每个数据点,计算该数据点在回归模型下的预测值与实际观测值之间的误差。然后将每个误差的平方累加到总误差 totalError 中。最后,通过将总误差除以数据点的数量,计算出平均误差并返回。

在这里插入图片描述

通过索引操作 points[i,0],我们可以获取第 i 个数据点的自变量 x 的值,因为它位于每行的第一列(索引为 0)
类似地,通过 points[i,1],我们可以获取第 i 个数据点的因变量 y 的值,因为它位于每行的第二列(索引为 1)
b: 回归模型的截距。
w: 回归模型的斜率。
points: 数据点的集合,其中每个数据点由自变量 x 和因变量 y 组成。

def compute_error_for_line_given_points(b,w,points):
    totalError=0
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        totalError+=(y-(w*x+b))**2
    return totalError/float(len(points))

2.梯度下降中的参数更新
首先初始化截距梯度 b_gradient 和斜率梯度 w_gradient 为 0。然后,通过迭代遍历每个数据点,计算每个数据点对应的梯度值,以便在下一步更新中使用。对于每个数据点,根据当前的截距和斜率计算出预测值,然后根据预测值与实际观测值之间的误差来计算梯度。最后,将所有数据点的梯度累加到总梯度中,并除以数据点的数量 N,以获得平均梯度。接下来使用梯度下降的更新规则来更新截距和斜率。根据当前的截距和斜率值,分别减去学习率乘以对应的梯度,得到新的截距 new_b 和斜率 new_w。最后,将更新后的截距和斜率作为列表返回。

在这里插入图片描述
b_current: 当前的截距值。
w_current: 当前的斜率值。
points: 数据点的集合,其中每个数据点由自变量 x 和因变量 y 组成。
learningRate: 学习率,用于控制每次更新的步长。

def step_gradient(b_current,w_current,points,learningRate):
    b_gradient=0
    w_gradient=0
    N=float(len(points))
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        b_gradient+=(2*(w_current*x+b_current-y))/N  # 对b偏导
        w_gradient+=(2*(w_current*x+b_current-y)*x)/N  # 对w偏导
    new_b=b_current-learningRate*b_gradient
    new_w=w_current-learningRate*w_gradient
    return [new_b,new_w]

3.梯度下降算法的主要循环部分
points: 数据点的集合,其中每个数据点由自变量 x 和因变量 y 组成。
starting_b: 初始的截距值。
starting_w: 初始的斜率值。
learning_rate: 学习率,用于控制每次更新的步长。
num_iterations: 迭代次数,表示要运行梯度下降的步骤数。

import numpy as np
def gradient_decent_runner(points,starting_b,starting_w,learing_rate,num_iterations):
    b=starting_b
    w=starting_w
    for i in range(num_iterations):
        b,w=step_gradient(b,w,np.array(points),learing_rate)
    return [b,w]  # 返回最后一次迭代结果,即最终数据

4.运行

def run():
    points=np.genfromtxt("data.csv",delimiter=",")  # data.csv更换为文件的存放地址
    learning_rate=0.0001
    initial_b=0
    initial_w=0
    num_iterations=1000
    print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points)))
    [b,w]=gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)
    print("After {0} interations b={1},w={2},error={3}".format(num_iterations,b,w,compute_error_for_line_given_points(b,w,points)))

这里的np.genfromtxt 是 NumPy 库中的一个函数,用于从文本文件加载数据并生成一个 NumPy 数组。该函数可以处理各种格式的文本数据,包括逗号分隔值(CSV)文件和具有不同分隔符的文件。
例如:存在一个名为 ‘data.csv’ 的 CSV 文件,其中的数据使用逗号作为分隔符。np.genfromtxt 函数将加载该文件的数据并生成一个 NumPy 数组,存储在变量 data 中,使用 print(data) 即可打印加载的数据。

import numpy as np
# 从名为 'data.csv' 的 CSV 文件中加载数据
data = np.genfromtxt('data.csv', delimiter=',')
# 打印加载的数据
print(data)

完整代码

import numpy as np
def compute_error_for_line_given_points(b,w,points):
    totalError=0
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        totalError+=(y-(w*x+b))**2
    return totalError/float(len(points))
def step_gradient(b_current,w_current,points,learningRate):
    b_gradient=0
    w_gradient=0
    N=float(len(points))
    for i in range(0,len(points)):
        x=points[i,0]
        y=points[i,1]
        b_gradient+=(2*(w_current*x+b_current-y))/N
        w_gradient+=(2*(w_current*x+b_current-y)*x)/N
    new_b=b_current-learningRate*b_gradient
    new_w=w_current-learningRate*w_gradient
    return [new_b,new_w]
def gradient_descent_runner(points,starting_b,starting_w,learing_rate,num_iterations):
    b=starting_b
    w=starting_w
    for i in range(num_iterations):
        b,w=step_gradient(b,w,np.array(points),learing_rate)
    return [b,w]
def run():
    points=np.genfromtxt("D:/Deep-Learning-with-PyTorch-Tutorials/lesson04-简单回归案例实战/data.csv",delimiter=",")
    learning_rate=0.0001
    initial_b=0
    initial_w=0
    num_iterations=1000
    print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points)))
    [b,w]=gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)
    print("After {0} interations b={1},w={2},error={3}".format(num_iterations,b,w,compute_error_for_line_given_points(b,w,points)))
run()

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/665445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【GESP】2023年03月图形化一级 -- 小猫捉老鼠

文章目录 小猫捉老鼠1. 准备工作2. 功能实现3. 设计思路与实现(1)角色、舞台背景设置a. 角色设置b. 舞台背景设置 (2)脚本编写a. 角色:Mouse1b. 角色:Cat 2 4. 评分标准 小猫捉老鼠 1. 准备工作 &#xff…

Vue3项目中使用vue-router

目录 1、Vue Router 的主要概念和功能2、什么是 vue-router?3、为什么需要 vue-router?4、基本概念和安装4.1 了解单页面应用(SPA)和路由的基本概念4.1.1单页面应用(Single Page Application,SPA)4.1.2路由…

【运维知识进阶篇】zabbix5.0稳定版详解3(监控Nginx+PHP服务状态信息)

这篇文章继续给大家介绍zabbix监控,监控Nginx、PHP等服务,其实非常简单,难点在于如何去取这个值,包括监控业务,难点在于思路是否清晰,思维是否活跃,如何去进行判断是否有这个业务,并…

小白到运维工程师自学之路 第三十四集 (redis的基本使用)

一、概念 Redis是一个开源的内存数据结构存储系统,它可以用作数据库、缓存和消息中间件。Redis支持多种数据结构,如字符串、哈希表、列表、集合、有序集合等。Redis的特点是数据存储在内存中,因此读写速度非常快,同时也支持数据持…

【Vue3+Ts project】认识 @vueuse/core 库

目标: 根据屏幕宽度改变 实现动态获取盒子的宽度 目录 目标: 一、javascript实现 二、vueuse/core 库实现 一、javascript实现 1.首先 window.innerWidth 获取当前屏幕宽度,然后将 盒子宽度 除 375 乘 当前屏幕宽度 150 / 375 * window.innerWidth 2.将获取的…

千万不要跟随这 4 种领导!

​ 见字如面,我是军哥! 最近有程序员读者问我,什么样的领导不能跟随?都有哪些坑!这个我擅长哈,毕竟职场混迹 15 年~ 第一种,技术能力不行还喜欢指手画脚的领导。 第二种,…

鹏云网络分布式块存储社区版问世,首发开源存储解决方案

2023年1月,南京鹏云网络科技有限公司(简称:鹏云网络)正式宣布开源ZettaStor DBS分布式块存储系统,开放了自研10余年的分布式块存储技术,自此踏上了“自研”与“开源”一体并行的生态闭环之路。 研发十年&am…

python程序获取最新的行政区划名称代码

一、实现目标 最近由于项目需要,需要获取最新的过去全国县以上行政区划的名称和代码。网上虽然有一些资料,但是不是需要积分就是需要会员,而且担心这些资料不是最新的。因此,想着使用程序从官方网站上获取最新的全国行政区划数据。 二、实现思路 1、找到官方最新发布的全国…

c++11 标准模板(STL)(std::basic_ios)(五)

定义于头文件 <ios> template< class CharT, class Traits std::char_traits<CharT> > class basic_ios : public std::ios_base 类 std::basic_ios 提供设施&#xff0c;以对拥有 std::basic_streambuf 接口的对象赋予接口。数个 std::basic_ios…

【夜深人静学数据结构与算法 | 第七篇】时间复杂度与空间复杂度

目录 前言&#xff1a; 引入&#xff1a; 时间复杂度&#xff1a; 案例&#xff1a; 空间复杂度&#xff1a; 案例&#xff1a; TIPS&#xff1a; 总结&#xff1a; 前言&#xff1a; 今天我们将来介绍时间复杂度和空间复杂度&#xff0c;我们代码的优劣就是依…

力扣算法刷题Day38|动态规划:斐波那契数 爬楼梯 使用最小花费爬楼梯

力扣题目&#xff1a;#509. 斐波那契数 刷题时长&#xff1a;参考答案后5min 解题方法&#xff1a;动态规划 复杂度分析 时间O(n)空间O(n) 问题总结 无 本题收获 动规五部曲思路 确定dp数组以及下标的含义&#xff1a;dp[i]的定义为&#xff0c;第i个数的斐波那契数值…

VMware虚拟机彻底卸载详细教程

VMware虚拟机彻底卸载 一、彻底卸载过程1.1 停止VMware服务1.2 结束vmware任务1.3 开始卸载VMware1.4 删除注册表信息1.5 删除安装目录 二、vmware 安装教程三、vmware 使用教程 回到目录   回到末尾 一、彻底卸载过程 卸载之前&#xff0c;需要先关闭VMware相关的后台服务…

软件技巧:7款冷门且十分良心的软件

1、Okular 阅读器 Okular是一款来自KDE的通用文档阅读器&#xff0c;支持众多文档格式&#xff0c;如PDF、Postscript、DjVu、CHM、XPS、ePub、图片格式、漫画格式等&#xff0c;支持Windows、macOS与Linux&#xff0c;是科研学术人士阅读文献的好工具&#xff0c;也是电子书爱…

OWASP 之认证崩溃基础技能

文章目录 一、burp爆破用法1.Attack type爆破方式设置2.payload处理3.请求引擎设置4.攻击结果设置5.grap匹配设置 二、常见端口与利用1、文件共享2、远程连接3、Web应用4、数据库 三、爆破案例经验1、暴力破解攻击产生的5个原因或漏洞2、猜测用户名方法3、猜测密码方法 四、实验…

亚马逊云科技中国峰会:Amazon DeepRacer——载着 AI 梦想向前奔跑

目录 一、Amazon DeepRacer 是什么&#xff1f; 二、Amazon DeepRacer 的前世今生 三、Amazon DeepRacer 深度体验 四、2023亚马逊云科技中国峰会 1.中国峰会总决赛 2.自动驾驶赛车名校邀请赛 3.Girls in Tech Show 4.全球联赛 5.报名链接&#xff1a; 一、Amazon Dee…

C++个人通信录管理系统

背景&#xff1a; 使用C编写一个个人通信录管理系统&#xff0c;来完成作业上的一些需求。 1-提供录入个人信息、修改个人信息&#xff08;姓名和出生日期除外&#xff09;、删除个人信息等编辑功能 2-提供按姓名查询个人信息的功能 3-提供查找在5天之内过生日的人员的信息…

【C++初阶】C++STL详解(二)—— string类的模拟实现

​ ​&#x1f4dd;个人主页&#xff1a;Sherry的成长之路 &#x1f3e0;学习社区&#xff1a;Sherry的成长之路&#xff08;个人社区&#xff09; &#x1f4d6;专栏链接&#xff1a;C初阶 &#x1f3af;长路漫漫浩浩&#xff0c;万事皆有期待 上一篇博客&#xff1a;【C初阶】…

Internet Relay Chat:mIRC 7.73 Crack

mIRC是一个流行的互联网中继聊天客户端&#xff0c;个人和组织使用它在世界各地的IRC网络上相互交流、共享、玩耍和工作。为互联网社区服务了20多年&#xff0c;mIRC已经发展成为一种强大、可靠和有趣的技术。 Latest News mIRC 7.73 has been released! (June 18th 2023) This…

Linux常用命令——fuser命令

在线Linux命令查询工具 fuser 使用文件或文件结构识别进程 补充说明 fuser命令用于报告进程使用的文件和网络套接字。fuser命令列出了本地进程的进程号&#xff0c;那些本地进程使用file&#xff0c;参数指定的本地或远程文件。对于阻塞特别设备&#xff0c;此命令列出了使…