机器学习代码基础——ML2 使用梯度下降的线性回归

news2025/4/18 8:16:09

ML2 使用梯度下降的线性回归

牛客网

描述

编写一个使用梯度下降执行线性回归的 Python 函数。该函数应将 NumPy 数组 X(具有一列截距的特征)和 y(目标)作为输入,以及学习率 alpha 和迭代次数,并返回一个 NumPy 数组,表示线性回归模型的系数。

输入描述:

第1行输入X,第2行输入y,第3行输入alpha,第4行输入迭代次数。

输出描述:

输出线性回归模型的系数,四舍五入到小数点后四位。返回类型是List类型。

输入:
[[1, 1], [1, 2], [1, 3], [1, 4]]
[2, 3, 4, 5]
0.01
1000

输出: 
[0.8678 1.045 ]
import numpy as np
def linear_regression_gradient_descent(X, y, alpha, iterations):
    # 补全代码
    m,n = X.shape
    theta = np.zeros((n,1)) # 为了和答案一致
    for _ in range(iterations):
        y_predict = X@theta
        errors = y_predict - y
        discent = X.T@(errors)/m
        theta = theta - alpha * discent
    return np.round(theta.flatten(), 4)

# 主程序
if __name__ == "__main__":
    # 输入矩阵和向量
    matrix_inputx = input()
    array_y = input()
    alpha = input()
    iterations = input()

    # 处理输入
    import ast
    matrix = np.array(ast.literal_eval(matrix_inputx))
    y = np.array(ast.literal_eval(array_y)).reshape(-1,1)
    alpha = float(alpha)
    iterations = int(iterations)

    # 调用函数计算逆矩阵
    output = linear_regression_gradient_descent(matrix,y,alpha,iterations)
    
    # 输出结果
    print(output)


[0.8678 1.045 ]

梯度下降求解

梯度下降是一种计算局部最小值的一种方法。梯度下降思想就是给定一个初始值𝜃,每次沿着函数梯度下降的方向移动𝜃:

θ ( t + 1 ) : = θ ( t ) − α ∇ θ J ( θ ( t ) ) \theta^{(t+1)} := \theta^{(t)} - \alpha \nabla_{\theta} J(\theta^{(t)}) θ(t+1):=θ(t)αθJ(θ(t))

在梯度为零或趋近于零的时候收敛
J ( θ ) = 1 2 n ∑ i = 1 n ( x i T θ − y i ) 2 J(\theta)=\frac{1}{2n}\sum^n_{i=1}(x_i^T\theta-y_i)^2 J(θ)=2n1i=1n(xiTθyi)2
对损失函数求偏导可得到 (n个样本,每个样本p维)
x i = ( x i , 0 , . . . , x i , p ) T x i j 表示第 i 个样本的第 j 个分量 ∂ θ j 1 2 n ( x i T θ − y i ) 2 = ∂ θ j 1 2 n ( ∑ j = 0 p x i , j θ j − y i ) 2 = 1 n ( ∑ j = 0 p x i , j θ j − y i ) x i , j = 1 n ( f ( x i ) − y i ) ) x i , j ∇ θ J = [ J θ 0 J θ 1 . . . J θ p ] x_i=(x_{i,0},...,x_{i,p})^T\\ x_{ij}表示第i个样本的第j个分量\\ \frac{\partial}{\theta_j}\frac{1}{2n}(x_i^T\theta-y_i)^2=\frac{\partial}{\theta_j}\frac{1}{2n}(\sum^p_{j=0}x_{i,j}\theta_j-y_i)^2=\frac{1}{n}(\sum^p_{j=0}x_{i,j}\theta_j-y_i)x_{i,j}=\frac{1}{n}(f(x_i)-y_i))x_{i,j} \\ \nabla_\theta J=\begin{bmatrix} \frac{J}{\theta_0}\\ \frac{J}{\theta_1}\\...\\ \frac{J}{\theta_p} \end{bmatrix} xi=(xi,0,...,xi,p)Txij表示第i个样本的第j个分量θj2n1(xiTθyi)2=θj2n1(j=0pxi,jθjyi)2=n1(j=0pxi,jθjyi)xi,j=n1(f(xi)yi))xi,jθJ= θ0Jθ1J...θpJ
对于只有一个训练样本的训练组而言,每走一步,𝜃𝑗(𝑗= 0,1,…,𝑝)的更新公式就可以写成:
θ j ( t + 1 ) : = θ j ( t ) − α ∂ ∂ θ j J ( θ j ( t ) ) = θ j ( t ) − α 1 n ( f ( x i ) − y i ) x i , j \theta_j^{(t+1)} := \theta_j^{(t)} - \alpha \frac{\partial}{\partial \theta_j} J(\theta_j^{(t)}) = \theta_j^{(t)} - \alpha \frac{1}{n} (f(x_i) - y_i) x_{i,j} θj(t+1):=θj(t)αθjJ(θj(t))=θj(t)αn1(f(xi)yi)xi,j
因此,当有 n 个训练实例的时候(批处理梯度下降算法),该公式就可以写为:
θ j ( t + 1 ) : = θ j ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i , j \theta_j^{(t+1)}:=\theta_j^{(t)}-\alpha\frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i,j} θj(t+1):=θj(t)αn1i=1n(f(xi)yi)xi,j
这样,每次根据所有数据求出偏导,然后根据特定的步长𝛼,就可以不断更新𝜃𝑗,直到其收敛。当梯度为0或目标函数值不能继续下降的时候,就可以说已经收敛,即目标函数达到局部最小值。

具体过程可以归纳如下

1️⃣ 初始化𝜃(随机初始化)

2️⃣ 利用如下公式更新𝜃
θ j ( t + 1 ) : = θ j ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i , j θ ( t + 1 ) : = θ ( t ) − α 1 n ∑ i = 1 n ( f ( x i ) − y i ) x i \theta_j^{(t+1)}:=\theta_j^{(t)}-\alpha \frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i,j}\\ \theta^{(t+1)}:=\theta^{(t)}-\alpha \frac{1}{n}\sum^n_{i=1}(f(x_i)-y_i)x_{i} θj(t+1):=θj(t)αn1i=1n(f(xi)yi)xi,jθ(t+1):=θ(t)αn1i=1n(f(xi)yi)xi
其中α为步长

3️⃣ 如果新的𝜃能使𝐽(𝜃)继续减少,继续利用上述步骤更新𝜃,否则收敛,停止迭代。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2329689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PostgreSQL 一文从安装到入门掌握基本应用开发能力!

本篇文章主要讲解 PostgreSQL 的安装及入门的基础开发能力,包括增删改查,建库建表等操作的说明。navcat 的日常管理方法等相关知识。 日期:2025年4月6日 作者:任聪聪 一、 PostgreSQL的介绍 特点:开源、免费、高性能、关系数据库、可靠性、稳定性。 官网地址:https://w…

WEB安全--内网渗透--LMNTLM基础

一、前言 LM Hash和NTLM Hash是Windows系统中的两种加密算法,不过LM Hash加密算法存在缺陷,在Windows Vista 和 Windows Server 2008开始,默认情况下只存储NTLM Hash,LM Hash将不再存在。所以我们会着重分析NTLM Hash。 在我们内…

8.用户管理专栏主页面开发

用户管理专栏主页面开发 写在前面用户权限控制用户列表接口设计主页面开发前端account/Index.vuelangs/zh.jsstore.js 后端Paginator概述基本用法代码示例属性与方法 urls.pyviews.py 运行效果 总结 欢迎加入Gerapy二次开发教程专栏! 本专栏专为新手开发者精心策划了…

室内指路机器人是否支持与第三方软件对接?

嘿,你知道吗?叁仟室内指路机器人可有个超厉害的技能,那就是能和第三方软件 “手牵手” 哦,接下来就带你一探究竟! 从技术魔法角度看哈:好多室内指路机器人都像拥有超能力的小魔法师,采用开放式…

从代码上深入学习GraphRag

网上关于该算法的解析都停留在大概流程上,但是具体解析细节未知,由于代码是PipeLine形式因此阅读起来比较麻烦,本文希望通过阅读项目代码来解析其算法的具体实现细节,特别是如何利用大模型来完成图谱生成和检索增强的实现细节。 …

【Redis】通用命令

使用者通过redis-cli客户端和redis服务器交互,涉及到很多的redis命令,redis的命令非常多,我们需要多练习常用的命令,以及学会使用redis的文档。 一、get和set命令(最核心的命令) Redis中最核心的两个命令&…

微前端随笔

✨ single-spa: js-entry 通过es-module 或 umd 动态插入 js 脚本 ,在主应用中发送请求,来获取子应用的包, 该子应用的包 singleSpa.registerApplication({name: app1,app: () > import(http://localhost:8080/app1.js),active…

C++中的浅拷贝和深拷贝

浅拷贝只是将变量的值赋予给另外一个变量,在遇到指针类型时,浅拷贝只会把当前指针的值,也就是该指针指向的地址赋予给另外一个指针,二者指向相同的地址; 深拷贝在遇到指针类型时,会先将当前指针指向地址包…

车载诊断架构 --- 整车重启先后顺序带来的思考

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…

【C++11(下)】—— 我与C++的不解之缘(三十二)

前言 随着 C11 的引入,现代 C 语言在语法层面上变得更加灵活、简洁。其中最受欢迎的新特性之一就是 lambda 表达式(Lambda Expression),它让我们可以在函数内部直接定义匿名函数。配合 std::function 包装器 使用,可以…

Windows 10/11系统优化工具

家庭或工作电脑使用时间久了,会出现各种各样问题,今天给大家推荐一款专为Windows 10/11系统设计的全能优化工具,该软件集成了超过40项专业级实用程序,可针对系统性能进行深度优化、精准调校、全面清理、加速响应及故障修复。通过系…

浅谈在HTTP中GET与POST的区别

从 HTTP 报文来看: GET请求方式将请求信息放在 URL 后面,请求信息和 URL 之间以 ?隔开,请求信息的格式为键值对,这种请求方式将请求信息直接暴露在 URL 中,安全性比较低。另外从报文结构上来看&#xff0c…

LightRAG实战:轻松构建知识图谱,破解传统RAG多跳推理难题

作者:后端小肥肠 🍊 有疑问可私信或评论区联系我。 🥑 创作不易未经允许严禁转载。 姊妹篇: 2025防失业预警:不会用DeepSeek-RAG建知识库的人正在被淘汰_deepseek-embedding-CSDN博客 从PDF到精准答案:Coze…

C++多线程编码二

1.lock和try_lock lock是一个函数模板,可以支持多个锁对象同时锁定同一个,如果其中一个锁对象没有锁住,lock函数会把已经锁定的对象解锁并进入阻塞,直到多个锁锁定一个对象。 try_lock也是一个函数模板,尝试对多个锁…

垃圾回收——三色标记法(golang使用)

三色标记法(tricolor mark-and-sweep algorithm)是传统 Mark-Sweep 的一个改进,它是一个并发的 GC 算法,在Golang中被用作垃圾回收的算法,但是也会有一个缺陷,可能程序中的垃圾产生的速度会大于垃圾收集的速度,这样会导…

Windows环境下开发pyspark程序

Windows环境下开发pyspark程序 一、环境准备 1.1. Anaconda/Miniconda(Python环境) 如果不怕包的版本管理混乱,可以直接使用已有的Python环境。 需要安装anaconda/miniconda(python3.8版本以上):Anaconda…

SSM婚纱摄影网的设计

🍅点赞收藏关注 → 添加文档最下方联系方式咨询本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅 项目视频 SS…

1110+款专业网站应用程序UI界面设计矢量图标figma格式素材 Icon System | 1,100+ Icons Easily Customize

1110款专业网站应用程序UI界面设计矢量图标figma格式素材 Icon System | 1,100 Icons Easily Customize 产品特点 — 24 x 24 px 网格大小 — 2px 线条描边 — 所有形状都是基于矢量的 — 平滑和圆角 — 易于更改颜色 类别 🚨 警报和反馈 ⬆️ 箭头 &…

Llama 4 家族:原生多模态 AI 创新的新时代开启

0 要点总结 Meta发布 Llama 4 系列的首批模型,帮用户打造更个性化多模态体验Llama 4 Scout 是有 170 亿激活参数、16 个专家模块的模型,同类中全球最强多模态模型,性能超越以往所有 Llama 系列模型,能在一张 NVIDIA H100 GPU 上运…

正则表达式(Regular Expression,简称 Regex)

一、5w2h(七问法)分析正则表达式 是的,5W2H 完全可以应用于研究 正则表达式(Regular Expressions)。通过回答 5W2H 的七个问题,我们可以全面理解正则表达式的定义、用途、使用方法、适用场景等&#xff0c…