【机器学习】单变量线性回归

news2025/1/12 6:11:48

文章目录

  • 线性回归模型(linear regression model)
  • 损失/代价函数(cost function)——均方误差(mean squared error)
  • 梯度下降算法(gradient descent algorithm)
  • 参数(parameter)和超参数(hyperparameter)
  • 代码实现样例
  • 运行结果

源代码文件请点击此处!

线性回归模型(linear regression model)

  • 线性回归模型:

f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b(x)=wx+b

其中, w w w 为权重(weight), b b b 为偏置(bias)

  • 预测值(通常加一个帽子符号):

y ^ ( i ) = f w , b ( x ( i ) ) = w x ( i ) + b \hat{y}^{(i)} = f_{w,b}(x^{(i)}) = wx^{(i)} + b y^(i)=fw,b(x(i))=wx(i)+b

损失/代价函数(cost function)——均方误差(mean squared error)

  • 一个训练样本: ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))
  • 训练样本总数 = m m m
  • 损失/代价函数是一个二次函数,在图像上是一个开口向上的抛物线的形状。

J ( w , b ) = 1 2 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] 2 = 1 2 m ∑ i = 1 m [ w x ( i ) + b − y ( i ) ] 2 \begin{aligned} J(w, b) &= \frac{1}{2m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}]^2 \\ &= \frac{1}{2m} \sum^{m}_{i=1} [wx^{(i)} + b - y^{(i)}]^2 \end{aligned} J(w,b)=2m1i=1m[fw,b(x(i))y(i)]2=2m1i=1m[wx(i)+by(i)]2

  • 为什么需要乘以 1/2?因为对平方项求偏导后会出现系数 2,是为了约去这个系数。

梯度下降算法(gradient descent algorithm)

  • α \alpha α:学习率(learning rate),用于控制梯度下降时的步长,以抵达损失函数的最小值处。若 α \alpha α 太小,梯度下降太慢;若 α \alpha α 太大,下降过程可能无法收敛。
  • 梯度下降算法:

r e p e a t { t m p _ w = w − α ∂ J ( w , b ) w t m p _ b = b − α ∂ J ( w , b ) b w = t m p _ w b = t m p _ b } u n t i l   c o n v e r g e \begin{aligned} repeat \{ \\ & tmp\_w = w - \alpha \frac{\partial J(w, b)}{w} \\ & tmp\_b = b - \alpha \frac{\partial J(w, b)}{b} \\ & w = tmp\_w \\ & b = tmp\_b \\ \} until \ & converge \end{aligned} repeat{}until tmp_w=wαwJ(w,b)tmp_b=bαbJ(w,b)w=tmp_wb=tmp_bconverge

其中,偏导数为

∂ J ( w , b ) w = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] x ( i ) ∂ J ( w , b ) b = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] \begin{aligned} & \frac{\partial J(w, b)}{w} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] x^{(i)} \\ & \frac{\partial J(w, b)}{b} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] \end{aligned} wJ(w,b)=m1i=1m[fw,b(x(i))y(i)]x(i)bJ(w,b)=m1i=1m[fw,b(x(i))y(i)]

参数(parameter)和超参数(hyperparameter)

  • 超参数(hyperparameter):训练之前人为设置的任何数量都是超参数,例如学习率 α \alpha α
  • 参数(parameter):模型在训练过程中创建或修改的任何数量都是参数,例如 w , b w, b w,b

代码实现样例

import numpy as np
import matplotlib.pyplot as plt

# 计算误差均方函数 J(w,b)
def cost_function(x, y, w, b):
    m = x.shape[0] # 训练集的数据样本数
    cost_sum = 0.0
    for i in range(m):
        f_wb = w * x[i] + b
        cost = (f_wb - y[i]) ** 2
        cost_sum += cost
    return cost_sum / (2 * m)

# 计算梯度值 dJ/dw, dJ/db
def compute_gradient(x, y, w, b):
    m = x.shape[0] # 训练集的数据样本数
    d_w = 0.0
    d_b = 0.0
    for i in range(m):
        f_wb = w * x[i] + b
        d_wi = (f_wb - y[i]) * x[i]
        d_bi = (f_wb - y[i])
        d_w += d_wi
        d_b += d_bi
    dj_dw = d_w / m
    dj_db = d_b / m
    return dj_dw, dj_db

# 梯度下降算法
def linear_regression(x, y, w, b, learning_rate=0.01, epochs=1000):
    J_history = [] # 记录每次迭代产生的误差值
    for epoch in range(epochs):
        dj_dw, dj_db = compute_gradient(x, y, w, b)
        # w 和 b 需同步更新
        w = w - learning_rate * dj_dw
        b = b - learning_rate * dj_db
        J_history.append(cost_function(x, y, w, b)) # 记录每次迭代产生的误差值
    return w, b, J_history

# 绘制线性方程的图像
def draw_line(w, b, xmin, xmax, title):
    x = np.linspace(xmin, xmax)
    y = w * x + b
    # plt.axis([0, 10, 0, 50]) # xmin, xmax, ymin, ymax
    plt.xlabel("X-axis", size=15)
    plt.ylabel("Y-axis", size=15)
    plt.title(title, size=20)
    plt.plot(x, y)

# 绘制散点图
def draw_scatter(x, y, title):
    plt.xlabel("X-axis", size=15)
    plt.ylabel("Y-axis", size=15)
    plt.title(title, size=20)
    plt.scatter(x, y)

# 从这里开始执行
if __name__ == '__main__':
    # 训练集样本
    x_train = np.array([1, 2, 3, 5, 6, 7])
    y_train = np.array([15.5, 19.7, 24.4, 35.6, 40.7, 44.8])
    w = 0.0 # 权重
    b = 0.0 # 偏置
    epochs = 10000 # 迭代次数
    learning_rate = 0.01 # 学习率
    J_history = [] # # 记录每次迭代产生的误差值

    w, b, J_history = linear_regression(x_train, y_train, w, b, learning_rate, epochs)
    print(f"result: w = {w:0.4f}, b = {b:0.4f}") # 打印结果

    # 绘制迭代计算得到的线性回归方程
    plt.figure(1)
    draw_line(w, b, 0, 10, "Linear Regression")
    plt.scatter(x_train, y_train) # 将训练数据集也表示在图中
    plt.show()

    # 绘制误差值的散点图
    plt.figure(2)
    x_axis = list(range(0, 10000))
    draw_scatter(x_axis, J_history, "Cost Function in Every Epoch")
    plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1444236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序(三十九)表单信息收集

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.表单收集的基本方法 2.picker的不足及解决方法 源码&#xff1a; index.wxml <!-- 用户信息 --> <view class"register"><!-- 绑定表单信息收集事件--><form bindsubmit"…

信号——block+pending+handler表

信号 注意 &#xff1a;这由三张表&#xff0c;block只能添加修改&#xff0c;pending只能获取 , handler只能修改 基础知识 抵达——> 执行 / 忽略sigset_t 信号集被阻塞的信号产生时将保持在未决状态,直到进程解除对此信号的阻塞,才执行递达的动作 信号集操作 #include &…

第75讲Avatar头像FooterHome实现

Avatar头像实现 avatar&#xff1a; <template><el-dropdown><span class"el-dropdown-link"><el-avatar shape"square" :size"40" :src"squareUrl" /></span><template #dropdown><el-drop…

【MySQL进阶之路】生产案例:数据库无法连接,Too many connections

欢迎关注公众号&#xff08;通过文章导读关注&#xff1a;【11来了】&#xff09;&#xff0c;及时收到 AI 前沿项目工具及新技术的推送&#xff01; 在我后台回复 「资料」 可领取编程高频电子书&#xff01; 在我后台回复「面试」可领取硬核面试笔记&#xff01; 文章导读地址…

6 scala-面向对象编程基础

Scala 跟 Java 一样&#xff0c;是一门面向对象编程的语言&#xff0c;有类和对象的概念。 1 类与对象 与 Java 一样&#xff0c;Scala 也是通过关键字 class 来定义类&#xff0c;使用关键字 new 创建对象。 要运行我们编写的代码&#xff0c;同样像 Java 一样&#xff0c;…

C#,最大公共子序列(LCS,Longest Common Subsequences)的算法与源代码

1 最大公共子序列 最长的常见子序列问题是寻找两个给定字符串中存在的最长序列。 最大公共子序列算法&#xff0c;常用于犯罪鉴定、亲子鉴定等等的 DNA 比对。 1.1 子序列 让我们考虑一个序列S<s1&#xff0c;s2&#xff0c;s3&#xff0c;s4&#xff0c;…&#xff0c;…

项目02《游戏-13-开发》Unity3D

基于 项目02《游戏-12-开发》Unity3D &#xff0c; 任务 &#xff1a;宠物系统 及 人物头像血条 首先在主面板MainPanel预制体中新建一个Panel&#xff0c; 命名为PlayerInfo 新建Image&#xff0c;作为头像 新建Slider&#xff0c;作为血条 对Panel组件添加一个水…

中年低端中产程序员从西安出发到海南三亚低成本吃喝万里行:西安-南宁-湛江-雷州-徐闻-博鳌-陵水-三亚-重庆-西安

文章大纲 旅途规划来回行程的确定南宁 - 北海 - 湛江轮渡成为了最终最大的不确定性&#xff01;感谢神州租车气温与游玩地点总体花费 游玩过程出发时间&#xff1a;Day1-1月25日星期四&#xff0c;西安飞南宁路途中&#xff1a;Day2-1月26日星期五&#xff0c;南宁-湛江-住雷州…

算法学习——LeetCode力扣二叉树篇1

算法学习——LeetCode力扣二叉树篇1 144. 二叉树的前序遍历 144. 二叉树的前序遍历 - 力扣&#xff08;LeetCode&#xff09; 描述 给你二叉树的根节点 root &#xff0c;返回它节点值的 前序 遍历。 示例 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&a…

猫头虎分享已解决Bug | Kotlin Error: Unresolved reference: name

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

Java多态原理

参考 虚方法 JVM杂记&#xff1a;对多态实现原理、虚方法表、虚方法、静态解析、动态链接的一些思考_多态和方法表的关系-CSDN博客 静态分派与动态分派 &#xff08;JVM&#xff09;Java虚拟机&#xff1a;静态分派 & 动态分派 原理解析 - 掘金 虚方法表 JVM 栈帧&am…

python+flask+django医院预约挂号系统6nrhh

医院预约挂号系统主要有管理员、用户和医生三个功能模块。以下将对这三个功能的作用进行详细的剖析。 技术栈 后端&#xff1a;python 前端&#xff1a;vue.jselementui 框架&#xff1a;django/flask Python版本&#xff1a;python3.7 数据库&#xff1a;mysql5.7 数据库工具…

python 基础知识点(蓝桥杯python科目个人复习计划35)

今日复习计划&#xff1a;阶段总结&#xff08;新年贺礼&#xff09; 1.python简介&#xff08;定义&#xff0c;优点&#xff0c;缺点&#xff0c;应用领域&#xff09; python&#xff1a;一种广泛使用的解释型&#xff0c;高级和通用的编程语言 python极简&#xff0c;生…

猫头虎分享已解决Bug || KeyError: ‘The truth value of a Series is ambiguous‘

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

【MySQL】-12 MySQL索引(上篇MySQL索引类型前置-1)

MySQL索引 索引1 索引基础2 索引与优化1 选择索引的数据类型1.1 选择标识符 2 索引入门2.1 索引的类型2.1.1 B-Tree索引2.1.2 Hash索引2.1.3 空间(R-Tree)索引2.1.4 全文(Full-text)索引 索引的优点&#xff1a;索引是最好的解决方案吗&#xff1f; 索引 索引&#xff08;在MYS…

08-OpenFeign-结合Sentinel,实现熔断降级

当我们在对服务远程调用时&#xff0c;会因为服务的请求超时、抛出异常等情况&#xff0c;导致调用失败。 如果短时间内&#xff0c;产生大量请求异常。引发上游的调用方请求积压&#xff0c;最终会引起整个调用链雪崩。 为此我们需要对核心的调用过程进行监控&#xff0c;当…

前端JavaScript篇之ajax、axios、fetch的区别

目录 ajax、axios、fetch的区别AjaxAxiosFetch总结注意 ajax、axios、fetch的区别 在Web开发中&#xff0c;ajax、axios和fetch都是用于与服务器进行异步通信的技术&#xff0c;但它们在实现方式和功能上有所不同。 Ajax 定义与特点&#xff1a;Ajax是一种在无需重新加载整个…

深度学习(14)--x.view()详解

在torch中&#xff0c;常用view()函数来改变tensor的形状 查询官方文档&#xff1a; torch.Tensor.view — PyTorch 2.2 documentationhttps://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view示例 1.创建一个4x4的二维数组进行测试 x torch.…

开局一个破碗的故事例子

在一个寒冷的冬日&#xff0c;一个瘦弱的小姑娘拿着一个破碗&#xff0c;孤独地走在被白雪覆盖的街道上。她的名字叫小梅&#xff0c;她的父母早逝&#xff0c;留下她一个人在这个世界上艰难地生活。 小梅的破碗里只有几个铜板&#xff0c;那是她前一天沿街乞讨所得&#xff0c…

Windows Server 2019 搭建并加入域

系列文章目录 目录 系列文章目录 文章目录 前言 一、域是什么&#xff1f; 二、配置服务器 1.实验环境搭建 1)实验服务器配置和客户端 2)实验环境 2.服务器配置 账户是域服务器的账户和密码 文章目录 Windows Server 2003 Web服务器搭建Windows Server 2003 FTP服务器搭…