深度学习 回归问题

news2024/11/15 12:32:20

1. 梯度下降算法

深度学习中, 梯度下降算法是是一种很重要的算法.

梯度下降算法与求极值的方法非常类似, 其核心思想是求解 x ′ x' x, 使得 x ′ x' x 在取 x ⋆ x^{\star} x 时, 可以使得 l o s s 函数 loss函数 loss函数 的值最小.

其中, 在求解 x ′ x' x 的过程中, 采用的是迭代的方法, 不断迭代逼近 $ x^{\star}$. 最基本的公式为:
x ′ = x − l r × ▽ x x' = x - lr \times \triangledown x x=xlr×x

其中 ▽ x \triangledown x x l o s s ′ ∣ x loss'|_{x} lossx , l r lr lr 为学习率, 以上述公式为基础,发展出了更多的求解器.

2. 噪声

在现实世界中, 数据总是会存在误差.
y = w ∗ x + b + ϵ ϵ ∼ N ( 0.01 , 1 ) y = w * x + b + \epsilon \enspace\enspace \epsilon \sim {N(0.01, 1)} y=wx+b+ϵϵN(0.01,1)

l o s s = ( W X + b − y ) 2 loss = (WX + b - y)^2 loss=(WX+by)2

3. 回归与分类

3.1 线性回归

预测范围为实数区间.

3.2 逻辑回归

加了压缩函数后, 压缩了预测范围[0, 1].

3.3 分类

如手写数字识别.

4. 优化

y = w x + b + ϵ y = wx + b + \epsilon y=wx+b+ϵ中, 通过已有的 x i x_i xi y i y_i yi求解 w w w ϵ \epsilon ϵ, 可以优化为以下问题:

在这里插入图片描述

5. 回归问题实践

5.1 计算给定点的误差

代码如下所示:

def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]  # 获取当前点的x坐标
        y = points[i, 1]  # 获取当前点的y坐标
        # 计算预测值与实际值之间的差的平方,并累加到总误差中
        totalError += (y - (w * x + b)) ** 2
        # 返回平均误差
    return totalError / float(len(points))

5.2 计算梯度下降的梯度, 更新b和w

在这里插入图片描述

代码如下所示:

def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))  # 点的总数
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        # 计算b和w的梯度
        # 梯度计算
        b_gradient += -(2 / N) * (y - (w_current * x + b_current))
        w_gradient += -(2 / N) * x * (y - (w_current * x + b_current))
    # 使用学习率更新b和w
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]

5.3 执行梯度下降算法, 迭代b和w

代码如下所示:

def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):
    b = starting_b
    m = starting_m  # 通常w用于表示斜率,但这里用m,可能是为了与初始变量名保持一致
    for i in range(num_iterations):
        b, m = step_gradient(b, m, np.array(points), learning_rate)
    return [b, m]

5.4 完整代码

import torch  # 导入torch库,但在此代码中未直接使用
import numpy as np  # 导入numpy库,用于处理数值数据


# 计算给定直线(由参数b和w定义)对于一组点的误差
def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]  # 获取当前点的x坐标
        y = points[i, 1]  # 获取当前点的y坐标
        # 计算预测值与实际值之间的差的平方,并累加到总误差中
        totalError += (y - (w * x + b)) ** 2
        # 返回平均误差
    return totalError / float(len(points))


# 计算梯度下降中的梯度,并更新直线参数b和w
def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))  # 点的总数
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        # 计算b和w的梯度
        # 梯度计算
        b_gradient += -(2 / N) * (y - (w_current * x + b_current))
        w_gradient += -(2 / N) * x * (y - (w_current * x + b_current))
    # 使用学习率更新b和w
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]


# 执行梯度下降算法以优化直线参数
def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):
    b = starting_b
    m = starting_m  # 通常w用于表示斜率,但这里用m,可能是为了与初始变量名保持一致
    for i in range(num_iterations):
        b, m = step_gradient(b, m, np.array(points), learning_rate)
    return [b, m]


# 主函数,用于运行梯度下降算法
def run():
    points = np.genfromtxt("data.csv", delimiter=",")  # 从CSV文件加载数据点
    learning_rate = 0.0001  # 设置学习率
    initial_b = 0  # 初始截距
    initial_m = 0  # 初始斜率(这里用m代替w)
    num_iterations = 1000  # 设置迭代次数
    # 在开始梯度下降之前,计算并打印初始误差
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}"
          .format(initial_b, initial_m,
                  compute_error_for_line_given_points(initial_b, initial_m, points)))
    print("Running...")
    # 执行梯度下降
    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)
    # 打印梯度下降后的结果和最终误差
    print("After {0} iterations b = {1}, w ={2}, error = {3}"
          .format(num_iterations, b, m,
                  compute_error_for_line_given_points(b, m, points))
          )


if __name__ == '__main__':
    run()  # 调用主函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2071676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

罗德与施瓦茨RS、UPV 音频分析仪 250KHZ 双通道分析仪UPL

罗德与施瓦茨 UPV 音频分析仪的规格包括&#xff1a; 模拟 双通道分析仪&#xff1a;带宽高达 250 kHz 生成正弦波信号&#xff1a;单通道最高 185 kHz&#xff08;需要 B1&#xff09;和双通道最高 80 kHz FFT本底噪声&#xff1a;< -140dB 固有频率响应&#xff08;20 …

链动 2+1 模式小程序 AI 智能名片商城源码培训邀约策略研究

摘要&#xff1a;本文深入剖析链动 21 模式小程序 AI 智能名片商城源码的培训邀约策略&#xff0c;从该源码的价值出发&#xff0c;阐述邀约的重要性&#xff0c;并详细介绍具体的邀约策略&#xff0c;旨在为相关培训活动提供切实可行的指导&#xff0c;提高邀约成功率&#xf…

前端如何快速切换node版本:nvm

安装之前最好卸载计算机已经安装的node&#xff08;通过Windows菜单找到Node.js的卸载程序&#xff0c;运行卸载程序&#xff09;。下载nvm安装包&#xff1a;nvm安装地址。安装nvm&#xff0c;选择nvm安装根路径指定nodejs的安装路径打开命令行&#xff0c;输入nvm -v 可查看版…

Object.create的原型继承

● 首先我们来从这种方法来创建一个和之前一样计算年龄的方法 const PersonProto {cacleAge() {console.log(2038 - birthYear);} };const zhangsan Object.create(PersonProto); console.log(zhangsan);● 发现确实可以实现原型继承的特性 const PersonProto {cacleAge()…

odoo17 group col 属性

odoo17 group col 属性 以前版本&#xff0c;col4,在17中不能用了&#xff0c;或者方法变了 <record id"hetong.addfj_wizard" model"ir.ui.view"><field name"name">合同附件</field><field name"model">het…

免费的大模型插件llm.nvim

llm.nvim&#xff08;https://github.com/StubbornVegeta/llm.nvim&#xff09;是一款基于cloudflare的免费大模型插件&#xff0c;你可以像使用ChatGPT一样和它进行对话 在使用这款插件之前&#xff0c;你需要注册cloudflare&#xff0c;获取你的account和API key。你可以在这…

RCE - - 无字母数字远程命令执行

题目源码 <?php if(isset($_GET[code])){$code $_GET[code];if(strlen($code)>35){die("Long.");}if(preg_match("/[A-Za-z0-9_$]/",$code)){die("NO.");}eval($code); }else{highlight_file(__FILE__); } 分析 这道题 code 接 get 传…

【Qt】常用控件QProgressBar

常用控件QProgressBar 使用QProgressBar表示一个进度条&#xff01;&#xff01;&#xff01; QProgressBar的核心属性 属性说明 minimum 进度条最⼩值 maximum 进度条最⼤值 value 进度条当前值 alignment ⽂本在进度条中的对⻬⽅式. Qt::AlignLeft : 左对⻬Qt::Alig…

AJAX(4)——XMLHttpRequest

XMLHttpRequest 定义&#xff1a;XMLHttpRequest(XHR)对象用于与服务器交互。通过XMLHttpRequest可以在不刷新页面的情况下请求特定URL&#xff0c;获取数据。这允许网页在不影响用于操作的情况下&#xff0c;更新页面的局部内容。XMLHttpRequest在AJAX编程中被大量使用 关系…

第6章 B+树索引

目录 6.1 没有索引的查找 6.1.1 在一个页中的查找 6.1.2 在很多页中查找 6.2 索引 6.2.1 一个简单的索引方案 6.2.2 InnoDB中的索引方案 6.2.2.1 聚簇索引 6.2.2.2 二级索引 6.2.2.3 联合索引 6.2.3 InnoDB的B树索引的注意事项 6.2.3.1 根页面万年不动窝 6.2.3.2 内节…

MYSQL————数据库的约束

1.约束类型 1.not null&#xff1a;指示某列不能存储null值 2.unique&#xff1a;保证某列的每行必须有唯一值 3.default&#xff1a;规定没有给列赋值时的默认值 4.primary key&#xff1a;not null和unique的结合。确保某列&#xff08;或两个或多个列的结合&#xff09;有唯…

qtcreator的vim模式下commit快捷键ctrl+g,ctrl+c没有反应的问题

首先开启vim后&#xff0c;CtrlG&#xff0c;CtrlC无法用 解决&#xff1a; 工具 -> 选项->FakeVim 转到Ex Command Mapping 搜索Commit 底栏Regular expression 输入commit &#xff08;理论上可以是随意的单词&#xff09; 设置好后&#xff0c;以后要运行&#x…

vue+uniapp

#vue支持的语法&#xff0c;基本上可以做uniapp中所使用&#xff08;指绝大部分&#xff09; #知识点&#xff1a;插值表达式&#xff0c;响应式&#xff0c;指令&#xff0c;事件&#xff0c;指令修饰符 #拥有一些案例&#xff0c;补充&#xff0c;以及说明了如何在vscode运…

如何在 Android 智能手机上恢复已删除的图片

面对现实&#xff0c;从手机图库中丢失照片总是令人不安的&#xff0c;无论您是无意中删除了它们&#xff0c;还是甚至出于冲动而生气。但是&#xff0c;我们在这里告诉您&#xff0c;与大多数人的看法相反&#xff0c;从画廊中删除图像并不会使它们不可挽回地丢失。以下是一些…

【MySQL进阶之路】内外链接

目录 内连接 外连接 左外连接 右外连接 个人主页&#xff1a;东洛的克莱斯韦克-CSDN博客 内连接 内连接实际上就是利用where子句对两种表形成的笛卡儿积进行筛选 select 字段 from 表1 inner join 表2 on 连接条件 and 其他条件&#xff1b; 外连接 外连接分为左外连接和…

【Java】—— 数组元素的查找:顺序查找与二分查找

目录 1、顺序查找 2、二分查找 1、顺序查找 在Java编程中&#xff0c;我们经常需要查找数组中某个元素的下标。有时&#xff0c;我们需要找到该元素第一次出现的位置&#xff0c;而有时则需要找到最后一次出现的位置。在本文中&#xff0c;我们将重点介绍如何查找元素第一次出…

AI依赖的隐患:技术能力退化、安全风险与社会不平等的未来

现代科技的浪潮中&#xff0c;ChatGPT等人工智能工具已经成为我们工作和生活的得力助手。然而&#xff0c;当这种便利变成了依赖&#xff0c;潜在的风险开始显现。过度依赖AI不仅可能导致技术能力的严重退化&#xff0c;还可能加剧信息安全问题和社会不平等。让我们深度剖析这三…

智慧社区信息系统建设:数据可视化与原型设计的力量

在数字化浪潮的推动下&#xff0c;智慧社区作为城市治理现代化的重要一环&#xff0c;正以前所未有的速度改变着我们的生活方式。智慧社区信息系统&#xff0c;作为支撑这一变革的核心&#xff0c;不仅要求高效的数据处理能力&#xff0c;还需具备直观的数据展示与强大的用户交…

zdppy+vue3+onlyoffice文档管理系统实战 20240823上课笔记 zdppy_cache框架的低代码实现

遗留问题 1、封装API2、有账号密码3、查询所有有效的具体数据&#xff0c;也就是缓存的所有字段 封装查询所有有效具体数据的方法 基本封装 def get_all(self, is_activeTrue, limit100000):"""遍历数据库中所有的key&#xff0c;默认查询所有没过期的:para…

服务器数据恢复—重建RAID失败导致数据丢失的数据恢复案例

服务器数据恢复环境&#xff1a; 某品牌服务器中有一组由4块SAS磁盘做的RAID5磁盘阵列。该服务器操作系统为windows server&#xff0c;运行了一个单节点Oracle&#xff0c;数据存储为文件系统&#xff0c;无归档。该oracle数据库的数据量不大&#xff0c;oracle数据库内只有一…