【深度学习】线性回归

news2024/10/6 10:36:23

Linear Regression

  • 一个例子
  • 线性回归
  • 机器学习中的表达
  • 评价函数好坏的度量:损失(Loss)
  • 损失函数(Loss function)
    • 哪个数据集的均方误差 (MSE) 高
  • 如何找出最优b和w?
  • 寻找最优b和w
  • 如何降低损失 (Reducing Loss)
  • 梯度下降法
    • 梯度
    • 计算梯度f^'^(x)i
    • 梯度下降法(gradient descent)
    • 学习速率(Learning rate)
    • 金发女孩原则(Goldilocks principle)
  • 如何训练模型
  • 习题

一个例子

相比凉爽的天气,蟋蟀在较为炎热的天气里鸣叫更频繁。
现有数据如下图,请预测鸣叫声与温度的关系。
在这里插入图片描述

线性回归

如图这种由点确定线的过程叫回归,既找出因变量和自变量之间的关系

y = mx + b

y指的是温度(摄氏度),即我们试图预测的值
m指的是直线的斜率
x指的是每分钟的鸣叫声次数,即输入特征的值
b指的是y轴截距

机器学习中的表达

y=𝑤𝑥+𝑏

x 特征(feature)
y 预测值(target)
w 权重(weight)
b 偏差(bias)

在这里插入图片描述
样本(samples) (x,y)

在这里插入图片描述

如何找出最优的函数 , 即找出一组最佳的w和b?
首先需要对函数的好坏进行评价

评价函数好坏的度量:损失(Loss)

反应模型预测(prediction)的准确程度。

如果模型(model)的预测完全准确,则损失为零
训练模型的目标是从所有样本中找到一组平均损失 “较小” 的权重和偏差。

在这里插入图片描述

显然,相较于左侧曲线图中的蓝线,右侧曲线图中的蓝线代表的是预测效果更好的模型

损失函数(Loss function)

损失函数是指汇总各样本损失的数学函数。

· 平方损失(又称为 L2 损失)
· 单个样本的平方损失= ( observation - prediction(x) )2 = ( y - y’ )2
· 均方误差 (mean-square error, MSE) 指的是样本的平均平方损失

在这里插入图片描述

哪个数据集的均方误差 (MSE) 高

以下两幅图显示的两个数据集,哪个数据集的均方误差 (MSE) 较高?

在这里插入图片描述

B图线上的 8 个样本产生的总损失为 0。不过,尽管只有两个点在线外,但这两个点的离线距离依然是左图中离群点的 2 倍。平方损失进一步加大差异,因此两个点的偏移量产生的损失是一个点的 4 倍。
· 根据损失函数的定义
· 对于A:MSE = (02 + 12 + 02 + 12 + 02 + 12 + 02 + 12 + 02 + 02)/10 = 0.4
· 对于B:MSE = (02 + 02 + 02 + 22 + 02 + 02 + 02 + 22 + 02 + 02)/10 = 0.8

→因此B的MSE较高。

如何找出最优b和w?

定义了损失函数,我们就可以评价任一函数的好坏,下一步如何找出最优b和w?
靠猜~
在这里插入图片描述

像猜价格游戏,参与者给出初始价钱,通过“高了”或“低了”的提示,逐渐接近正确的价格。
在这里插入图片描述

寻找最优b和w

1. 首先随机给出一组参数b=1,w=1 
2. 评价这组参数的好坏,例如用MSE 
3. 改变w和b的值
4. 转到步骤2,直到总体损失不再变化或变化极其缓慢为止,该模型已收敛

在这里插入图片描述

最后一个问题:如何改变w、b ???

如何降低损失 (Reducing Loss)

简化问题,以只有一个参数w为例,所产生的损失Loss与 w 的图形是凸形(convex)。如下所示:

在这里插入图片描述

只有一个最低点 → 即只存在一个斜率正好为 0 的位置。
损失函数取到最小值的地方。

但是,如何找到这一点呢?

  1. 为w选择一个起点↓ 这里选择了一个稍大于 0 的起点
    在这里插入图片描述

梯度下降法

梯度

· 梯度是偏导数的矢量;有方向和大小。

· 梯度即是某一点最大的方向导数,沿梯度方向函数有最大的变化率
(沿梯度方向函数增加,负梯度方向函数减少)

· 损失Loss相对于单个权重的梯度大小就等于导数f(x)i

二元函数的梯度:
在这里插入图片描述

计算梯度f(x)i

求导数,即切线的斜率,在这个例子中,是负的,因此负梯度是w的正方向。

在这里插入图片描述

梯度下降法(gradient descent)

w=w-lr*w_grad

lr (learning rate,学习速率)

在这里插入图片描述

然后,重复此过程,逐渐接近最低点。
在这里插入图片描述
⭐负梯度指向损失函数下降的方向

学习速率(Learning rate)

通常梯度下降法用梯度乘以学习速率(步长),以确定下一个点的位置:

w=w-lr*w_grad

例如,如果学习速率为 0.01,梯度大小为 2.5,则:
w=w-0.01*2.5

学习速率是机器学习算法中人为引入的,是用于调整机器学习算法的旋钮,这种称为超参数。
在这里插入图片描述

金发女孩原则(Goldilocks principle)

每个回归问题都存在一个“Goldilocks principle”学习速率,该值与损失函数的平坦程度相关。 例如,如果损失函数的梯度较小,则可以采用更大的学习速率,以补偿较小的梯度并获得更大的步长。

(西方有一个儿童故事叫 “ The Three Bears(金发女孩与三只小熊)”,迷路了的金发姑娘未经允许就进入了熊的房子,她尝了三只碗里的粥,试了三把椅子,又在三张床上躺了躺。最后发现不烫不冷的粥最可口,不大不小的椅子坐着最舒服,不高不矮的床上躺着最惬意。道理很简单,刚刚好就是最适合的,just the right amount,这样做选择的原则被称为 Goldilocks principle(金发女孩原则)。)

采取恰当学习速率,可以高效的到达最低点,如下图所示:

在这里插入图片描述
降低损失:优化学习速率-模拟动图

如何训练模型

  1. 定义一个函数的集合
  2. 给出评价函数好坏的方法
  3. 利用梯度下降法找到最佳函数

在这里插入图片描述

习题

在这里插入图片描述

import numpy as np
import matplotlib.pyplot as plt


def gradient_descent(x, y, theta, learning_rate, epochs):
    ws = []
    bs = []
    for i in range(epochs):
        y_pred = x.dot(theta)
        diff = y_pred - y
        loss = 0.5 * np.mean(diff ** 2)
        g = x.T.dot(diff)
        theta -= learning_rate * g
        ws.append(theta[0][0])
        bs.append(theta[1][0])
        learning_rate = learning_rate_shedule(i + 1)
        print(f'第{i + 1}次梯度下降后,损失为{round(loss, 5)},w为{round(theta[0][0], 5)},b为{round(theta[1][0], 5)}')
        if loss < 0.001:
            break
    return ws, bs


def learning_rate_shedule(t):
    return t0 / (t + t1)



 

if __name__ == '__main__':
    xdata = np.array([8, 3, 9, 7, 16, 5, 3, 10, 4, 6])
    ydata = np.array([30, 21, 35, 27, 42, 24, 10, 38, 22, 25])
    x_data = np.array(xdata).reshape(-1, 1)
    y_data = np.array(ydata).reshape(-1, 1)
    X = np.concatenate([x_data, np.full_like(x_data, fill_value=1)], axis=1)
    theta = np.random.randn(2, 1)
    t0 = 1.5
    t1 = 1000
    ws, bs = gradient_descent(X, y_data, theta, learning_rate_shedule(0), 10000)
    ax = plt.subplot(3, 2, 1)
    bx = plt.subplot(3, 2, 2)
    cx = plt.subplot(3, 1, (2, 3))

    # 散点+预测线
    ax.scatter(x_data, y_data)
    x = np.linspace(x_data.min() - 1, x_data.max() + 1, 100)
    y = ws[-1] * x + bs[-1]
    ax.plot(x, y, color='red')
    ax.set_xlabel('x')
    ax.set_ylabel('y')

    #w,b变化
    x = np.arange(1, len(ws) + 1)
    bx.plot(x, ws, label='w')
    bx.plot(x, bs, label='b')
    bx.set_xlabel('epoch')
    bx.set_ylabel('change')
    bx.legend()

   #等高线
    def get_loss(w, b):
        return 0.5 * np.mean((y_data - (w * x_data + b)) ** 2)
    b_range = np.linspace(-100, 100, 100)
    w_range = np.linspace(-10, 10, 100)
    losses = np.zeros((len(b_range), len(w_range)))
    for i in range(len(b_range)):
        for j in range (len(w_range)):
            losses[i, j] = get_loss(w_range[j], b_range[i])
    cx.contour(b_range, w_range, losses, cmap='summer')
    cx.contourf(b_range, w_range, losses)
    cx.set_xlabel('b')
    cx.set_ylabel('w')
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【毕设级项目】基于AI技术的多功能消防机器人(完整工程资料源码)

基于AI技术的多功能消防机器人演示效果 竞赛-基于AI技术的多功能消防机器人视频演示 前言 随着“自动化、智能化”成为数字时代发展的关键词&#xff0c;机器人逐步成为社会经济发展的重要主体之一&#xff0c;“机器换人”成为发展的全新趋势和时代潮流。在可预见的将来&#…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《计及台区资源聚合功率的中低压配电系统低碳优化调度方法》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

软件无线电系列——软件无线电的发展历程及体系框架

本节目录 一、软件无线电的起始 二、软件无线电SDR论坛 三、SPEAKeasy计划 四、JTRS与SCA 五、软件无线电体系框架本节内容 一、软件无线电的起始 1992年5月&#xff0c;美国电信会议上&#xff0c;Joseph Mitola III博士提出来软件无线电(Software Radio,SR)的概念。理想化的…

实现支持多选的QComboBox

Qt提供的QComboBox只支持下拉列表内容的单选&#xff0c;但通过QComboBox提供的setModel、setView、setLineEdit三个方法&#xff0c;可以对QComboBox进行改造&#xff0c;使其实现下拉列表选项的多选。 QComboBox可以看作两个组件的组合&#xff1a;一个QLineEdit和一个QList…

OpenCV开发笔记(七十七):相机标定(二):通过棋盘标定计算相机内参矩阵矫正畸变摄像头图像

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/136616551 各位读者&#xff0c;知识无穷而人力有穷&#xff0c;要么改需求&#xff0c;要么找专业人士&#xff0c;要么自己研究 红胖子(红模仿)的博…

Orange3数据预处理(预处理器组件)

1.组件介绍 Orange3 提供了一系列的数据预处理工具&#xff0c;这些工具可以帮助用户在数据分析之前准备好数据。以下是您请求的预处理组件的详细解释&#xff1a; Discretize Continuous Variables&#xff08;离散化连续变量&#xff09;&#xff1a; 这个组件将连续变量转…

利用Nginx正向代理实现局域网电脑访问外网

引言 在网络环境中&#xff0c;有时候我们需要让局域网内的电脑访问外网&#xff0c;但是由于网络策略或其他原因&#xff0c;直接访问外网是不可行的。这时候&#xff0c;可以借助 Nginx 来搭建一个正向代理服务器&#xff0c;实现局域网内电脑通过 Nginx 转发访问外网的需求…

算法(结合算法图解)

算法简介简单查找二分查找法 选择排序内存的工作原理数组和链表数组选择排序小结 递归小梗 要想学会递归&#xff0c;首先要学会递归。 递归的基线条件和递归条件递归和栈小结 快速排序分而治之快速排序合并排序时间复杂度的平均情况和最糟情况小结 算法简介 算法是一组完成任…

Python3虚拟环境之virtualenv

virtualenv 在开发Python应用程序的时候&#xff0c;系统安装的Python3只有一个版本&#xff1a;3.7。所有第三方的包都会被pip安装到Python3的site-packages目录下。 如果要同时开发多个应用程序&#xff0c;这些应用程序都会共用一个Python&#xff0c;就是安装在系统的Pyt…

【算法】一维前缀和以及二维前缀和

目录 一维前缀和适用场景示例 二维前缀和适用场景一种情况另一种情况示例 一维前缀和 适用场景 求一段区间的和。 比如有一个数列 &#xff1a; 如果我们要求 [l,r]即某个区间内的数组和的时候&#xff0c;思路就是每遍历一个元素就进行求和&#xff0c;记录下加到al时的和…

HYBBS 表白墙网站PHP程序源码,支持封装成APP

PHP表白墙网站源码&#xff0c;适用于校园内或校区间使用&#xff0c;同时支持封装成APP。告别使用QQ空间的表白墙。 简单安装&#xff0c;只需PHP版本5.6以上即可。 通过上传程序进行安装&#xff0c;并设置账号密码&#xff0c;登录后台后切换模板&#xff0c;适配手机和PC…

Java双非大二找实习记录

先说结论&#xff1a;2.22→3.6线上线下面了七家&#xff0c;最后oc两家小公司&#xff0c;接了其中一个。 本人bg&#xff1a; 真名不经传双非一本&#xff0c;无绩点无竞赛无奖项无实习&#xff0c;23年12月开始学java。若非要说一点相关的经历&#xff0c;就是有java基础&…

python-0003-pycharm开发虚拟环境中的项目

前言 在虚拟环境中创建好了python项目&#xff0c;使用pycharm进行开发 打开项目 使用pycharm打开项目 设置虚拟环境的解释器 File–>Settings–>Project(项目名)–>Python Interpreter–>添加解释器–>添加已经存在的解释器–>选择虚拟环境的解释器 …

程序人生——Java开发中通用的方法和准则,Java进阶知识汇总

目录 引出Java开发中通用的方法和准则建议1:不要在常量和变量中出现易混淆的字母建议2:莫让常量蜕变成变量建议3:三元操作符的类型务必一致建议4:避免带有变长参数的方法重载建议5:别让null值和空值威胁到变长方法建议6:覆写变长方法也循规蹈矩建议7:警惕自增的陷阱建议…

linux查看文件内容cat,less,vi,vim

学习记录 目录 catlessvi vim cat 输出 FILE 文件的全部内容 $ cat [OPTION] FILE示例 输出 file.txt 的全部内容 $ cat file.txt查看 file1.txt 与 file2.txt 连接后的内容 $ cat file1.txt file2.txt为什么名字叫 cat&#xff1f; 当然和猫咪没有关系。 cat 这里是 co…

Python 导入Excel三维坐标数据 生成三维曲面地形图(面) 4-2、线条平滑曲面(原始颜色)但不去除无效点

环境和包: 环境 python:python-3.12.0-amd64包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 scipy 1.12.0 代码: import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.interpolate import griddata fr…

未来城市:数字孪生技术助力智慧城市构建

目录 一、数字孪生技术的兴起与定义 二、数字孪生技术在智慧城市构建中的应用 1、城市规划与管理 2、智慧交通 3、智慧能源 4、智慧环保 三、数字孪生技术助力智慧城市构建的挑战与对策 四、结语 随着科技的飞速发展&#xff0c;未来城市正在经历一场前所未有的变革。数…

Redis 的 key 的过期策略是怎么实现的【经典面试题】

前言 在 Redis 中可以通过命令 expire 对指定的 key 值设置过期时间&#xff0c;在时间到了以后该键值对就会自动删除。 一个 Redis 中可能会存在很多的 key &#xff0c;而这些 key 中有很大的一部分都会有过期时间&#xff0c;那么 Redis 怎么知道哪些 key 已经到了过期时间需…

【C语言】InfiniBand驱动mlx4_register_interface函数

一、讲解 mlx4_register_interface函数是Mellanox InfiniBand驱动程序的一部分&#xff0c;这个函数的作用是注册一个新的接口(intf)到InfiniBand设备。这允许不同的子系统&#xff0c;如以太网或存储&#xff0c;能够在同一个硬件设备上注册它们各自需要的接口&#xff0c;在…

Vue3全家桶 - VueRouter - 【6】导航守卫

导航守卫 查看以下情形&#xff1a; 点击主页链接时&#xff0c;默认情况下可直接进入指定页面&#xff0c;如下图&#xff0c;但是问题是该跳转的界面是需要用户登录后方可访问的&#xff1b; 可设置导航守卫来检测用户是否登录&#xff0c;如果已登录&#xff0c;则进入后台…