机器学习吴恩达笔记第一篇——基于梯度下降的线性回归(零基础)

news2025/1/10 14:12:36

机器学习吴恩达笔记第一篇——基于梯度下降的线性回归(零基础)

一、线性回归——理论(单变量)

  • 1、 假设函数h(x)为:

h ( x ) = θ 0 + θ 1 X h(x)=\theta_0+\theta_1 X h(x)=θ0+θ1X

  • 2、要拟合数据成一条直线,需要该直线上的预测的h(x)的值减去原来数据的y值

m i n i z e θ 0 , θ 1 ( J ( θ ) ) = ∑ i = 1 m ( h θ ( x i ) − y i ) 2 / 2 m \mathop{minize}\limits_{\theta_0,\theta_1} (J(\theta))=\sum\limits_{i=1}^m(h_\theta(x^i)-y^i)^2/2m θ0,θ1minize(J(θ))=i=1m(hθ(xi)yi)2/2m

  • 3、通过寻找一个合适的学习率与迭代次数来求得对应的 θ 0 , θ 1 \theta_0 , \theta_1 θ0,θ1

​ repeat until end{ #迭代到结束
t e m p = θ j − α × ( ∂ J ( θ ) ∂ θ j ) temp=\theta_j-\alpha \times(\frac{\partial J(\theta)}{\partial \theta_j}) temp=θjα×(θjJ(θ))
θ j = t e m p \theta_j=temp θj=temp

}

  • 注意: θ j \theta_j θj需要同时更新;

  • 4、计算出来 θ \theta θ值后,我们来判断 J ( θ ) J(\theta) J(θ)是否收敛:就是绘制在迭代过程中 J ( θ ) J(\theta) J(θ)的值;

在这里插入图片描述

好了,理论部分就结束了,下一节就是介绍本节的基础知识;

二、基础知识(矩阵的乘法、梯度下降、学习率、迭代次数)

1.矩阵的乘法:

  • 行乘列,列要与行相等;
import numpy as np
a=[1,2,3] #1*3
b=[4,5,6] #1*3
a=np.array(a)
b=np.array(b)
c=a.T*b 
#a.T:3*1
#  b:1*3
print(c)
#[ 4 10 18]

2.梯度下降

​ **repeat until end{ ** #迭代到结束

θ j = θ j − α × ( ∂ J ( θ ) ∂ θ j ) \theta_j=\theta_j-\alpha \times(\frac{\partial J(\theta)}{\partial \theta_j}) θj=θjα×(θjJ(θ))

}

因为 J ( θ ) J(\theta) J(θ)是一个 ∑ i = 1 m ( h θ ( x i ) − y i ) 2 / 2 m \sum\limits_{i=1}^m(h_\theta(x^i)-y^i)^2/2m i=1m(hθ(xi)yi)2/2m所以要求和;

  • 使用迭代方法:
while i<iteration: #iteration=迭代次数;
    temp_0=sum((theta_0+theta_1*size+theta_2*num-price))        #theta_0
    temp_1=sum(size.T*(theta_0+theta_1*size+theta_2*num-price)) #theta_1
    i+=1
  • (theta_0+theta_1*size+theta_2*num-price)size.T*(theta_0+theta_1*size+theta_2*num-price)都是求得偏导;

3.学习率

  • 是决定收敛的重要条件:
  1. 收敛的时间长短:
  • α \alpha α: 学习率;

repeat until end{ #迭代到结束

θ j = θ j − α × ( ∂ J ( θ ) ∂ θ j ) \theta_j=\theta_j-\alpha \times(\frac{\partial J(\theta)}{\partial \theta_j}) θj=θjα×(θjJ(θ))

}
在这里插入图片描述

α \alpha α合适的时候会快速收敛;

  1. 是否收敛:

α \alpha α太大会不收敛,反复横跳;

4.迭代次数


当迭代次数过少,就会不收敛;
在这里插入图片描述

次数太少,没有收敛;

三、实现的代码

  1. Python库
import matplotlib.pyplot as plt
import numpy as np
import math
  1. 读取文件
with open('ex1data1.txt','r') as f: # 文件划分;
    lines=f.readlines()
x=[]
y=[]
data=[]
for line in lines:
    data1=line.strip().split(',')
    x.append(float(data1[0]))
    y.append(float(data1[1]))
    data.append([float(data1[0]),float(data1[1])])
print(data)
  1. 画散点图
plt.scatter(x,y,c='red',marker='+',s=50,label='scatter') #散点图;
plt.xlabel('population')
plt.ylabel('predict')
plt.show()

在这里插入图片描述
在这里插入图片描述

  1. 直线方程
def h(a,b,x): #函数方程;
    return a+b*x
  1. 拟合直线
m=len(x)#长度
rates=[0.012]#,0.003,0.01,0.03,0.1,0.3,1,3] #学习率;
iterations=[10000]#,30,100,300,1000]  #迭代次数;
s=[]
#"""
for rate in rates: #学习率;
    for iteration in iterations: # 迭代次数;
        i=0
        theta_0=0
        theta_1=0
        # 从0开始;
        J=[]
        flag=0
        while i<iteration:
            total=0
            temp_0=0
            temp_1=0
            for j,k in data:
                temp_0+=(theta_0+theta_1*j-k)  #J(theta_0)的偏导部分;
                temp_1+=(theta_0+theta_1*j-k)*j #J(theta_1)的偏导部分;
            #采集J(theta);
            #更新theta_0与theta_1;
            theta_0=theta_0-(temp_0*rate/m)
            theta_1=theta_1-(temp_1*rate/m)
            i+=1
            for j,k in data:
                total+=math.pow((theta_0+theta_1*j-k),2)
            J.append(total/(2*m))
        theta=range(len(J)) #长度;
        print('最后一的斜率*学习率是:{}'.format(temp_0)) #最后一个的值
        print('最后一的值是:{} {}'.format(theta_0,theta_1))
        x=np.linspace(4,24,200)
        y=h(theta_0,theta_1,x)
        plt.plot(x,y,label='linear regression')
        plt.xlabel('theta')
        plt.ylabel('J(theta)')
        plt.legend()
plt.show() 

在这里插入图片描述

  • 合并起来的效果是:
    在这里插入图片描述

多元线性回归

  • 多元线性回归只需要将方程改为:

y = a + b × θ 0 + c × θ 1 y=a+b \times \theta_0+c \times \theta_1 y=a+b×θ0+c×θ1

从而按上面的步骤来一步一步的做;

# -*- coding: utf-8 -*-
"""
Created on Fri May 12 14:27:25 2023

@author: windows
"""
import matplotlib.pyplot as plt
import numpy as np
def h(a,b,c,x,y):
    return a+b*x+c*y
def normoalize_feature(data): #归一化;
    data1=[]
    for data2 in data:
     data1.append((data2-data2.mean())/data2.std())
    return data1
def get_data(data):
    with open(data,'r') as f:
        lines=f.readlines()
    size,num,price=[],[],[] #房子的大小,卧室的数量,房子的价格;
    for line in lines:
        data1=(line.strip().split(','))
        size.append(float(data1[0]))     #缩小特征值;km^2
        num.append(float(data1[1]))
        price.append(float(data1[2]))  
    size,num,price=np.array(size),np.array(num),np.array(price)
    h=[size,num,price]
    return h
size,num,price=normoalize_feature(get_data('ex1data2.txt'))
print(size,num,price)
theta_0,theta_1,theta_2=0,0,0
# 3D
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(num,size,price,marker='+',c='red',cmap='coolwarm')
ax.set_xlabel('num')
ax.set_ylabel('size')
ax.set_zlabel('price')

rate=0.012
iteration=100
i=0
J=[] #J(theta)的值;
total=0
m=len(size)
while i<iteration:
    #求theta;
    temp_0,temp_1,temp_2=0,0,0
    temp_0=sum((theta_0+theta_1*size+theta_2*num-price)) #theta_0
    temp_1=sum(size.T*(theta_0+theta_1*size+theta_2*num-price)) #theta_1;
    temp_2=sum(num.T*(theta_0+theta_1*size+theta_2*num-price))
    theta_0,theta_1,theta_2=theta_0-rate*temp_0,theta_1-rate*temp_1,theta_2-rate*temp_2
    total=(theta_0+theta_1*size+theta_2*num-price).T*(theta_0+theta_1*size+theta_2*num-price)
    J.append(sum(total)/2*m)
    i+=1 
print(theta_0,theta_1,theta_2)
plt.subplot(projection='3d')
x=np.linspace(-3,3,200)
y=np.linspace(-3,3,200)
z=h(theta_0,theta_1,theta_2,x,y)
plt.plot(x,y,z)
plt.show()

效果就是:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/548105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何让 300 万程序员爱上 CODING?

**《DNSPod十问》**是由腾讯云企业中心推出的一档深度谈话栏目&#xff0c;通过每期向嘉宾提出十个问题&#xff0c;带着广大读者站在产业互联网、科技领域精英的肩膀上&#xff0c;俯瞰各大行业发展趋势和前沿技术革新。 刘毅&#xff0c;腾讯云 CODING CEO、腾讯云开发者产品…

第十六章_Redis案例落地实战bitmap/hyperloglog/GEO

统计的类型有哪些 亿级系统中常见的四种统计 聚合统计 统计多个集合元素的聚合结果&#xff0c;就是前面讲解过的交差并等集合统计 复习命令 交并差集和聚合函数的应用 排序统计 抖音短视频最新评论留言的场景&#xff0c;请你设计一个展现列表。考察你的数据结构和设计思…

Nsudo,建议有编程基础的人使用,获取管理员和超级管理员权限

资源地址&#xff1a; https://download.csdn.net/download/yaosichengalpha/87801699 Nsudo,建议有编程基础的人使用&#xff0c;获取管理员和超级管理员权限 NSudo是一款非常不错的系统管理工具&#xff0c;他是基于raymai97的超级命令提示符&#xff0c;可以帮助我们获取T…

MybatisPlus--基础入门!真滴方便

目录 一、简介 2.特性 二、入门 1.创建springboot 项目 注意&#xff1a;引入 MyBatis-Plus 之后请不要再次引入 MyBatis 以及 MyBatis-Spring&#xff0c;以避免因版本差异导致的问题 2.数据准备 3.配置application.yml 4.代码 BaseMapper<>很重要&#xff01;…

vue 本地/PC端访问微信云数据库

1. 解决跨域访问问题 新建文件 vue.config.js // 后端服务器地址 let url "http://localhost:8888"; module.exports {publicPath: "./", // 【必要】静态文件使用相对路径outputDir: "./dist", //打包后的文件夹名字及路径devServer: {// 开…

组合数学第二讲

可以把取出来的数从小到大排序&#xff0c;第一个数不变&#xff0c;第二个数1&#xff0c;以此类推... 总共的情况为&#xff0c;数字取完后可再依次减回去&#xff0c;保证数在100以内 k-element multisets 引出下面的二项式系数 binomial coefficients&#xff08;二项式系…

线段树C++实现

一、本题线段树数组数据和结构 data[]{1,2,-3,5,6,-2,7,1,12,30,-10}&#xff0c;11个元素。 二、各个函数和结构 &#xff08;一&#xff09;线段树结构 创建线段树的结构&#xff0c; l、r为左边界和右边界&#xff0c;maxV和minV为最大值和最小值&#xff0c;sum为和&#…

English Learning - L3 作业打卡 Lesson2 Day12 2023.5.16 周二

English Learning - L3 作业打卡 Lesson2 Day12 2023.5.16 周二 引言&#x1f349;句1: Dollars are called greenbacks because that is the color of the back side of the paper money.成分划分弱读连读爆破语调 &#x1f349;句2: The color black is used often in expres…

抽象 + 接口 + 内部类

抽象类和抽象方法 抽象类不能实例化抽象类不一定有抽象方法&#xff0c;有抽象方法的类一定是抽象方法可以有构造方法抽象类的子类 要么重写抽象类中的所有抽象方法要么是抽象类 案例 Animal类Dog类 Sheep类Test类 接口 接口抽象类针对事物&#xff0c;接口针对行为案…

使用Google浏览器开启New bing

简介 搭建 通过谷歌商店下载两个浏览器插件&#xff0c;一个用于修改请求头agent的插件和一个用于伪造来源的插件x-forwarded-for插件&#xff0c;当然类似的插件很多很多&#xff0c;我这里使用的两个插件是 User-Agent Switcher Header Editor 使用 User-Agent Switcher 插件…

云HIS住院业务模块常见问题及解决方案

一&#xff1a;住院业务 1.患者办理住院时分配了错误的病区怎么办&#xff1f; 操作员误操作将患者分配了错误的病区分为以下两种情况&#xff1a; &#xff08;1&#xff09;、患者刚刚入院&#xff0c;未分配床位、主治医师与管床护士&#xff1a;这种情况比较好处理&#xf…

文件转pdf

背景 项目中很多预览工具&#xff0c;文件转pdf预览&#xff0c;采用libreoffice6.1插件实现 环境说明 系统CentOS&#xff1a;CentOS7 libreoffice&#xff1a;6.1 下载 中文官网 https://zh-cn.libreoffice.org/download/libreoffice/ 下载其他老版本 Index of /lib…

不敢妄谈K12教育

做为大学生的父亲&#xff1a;不敢妄谈孩子教育 大约10年前&#xff0c;写了一本教育书稿 找到一个出版社的编辑&#xff0c;被训了一通 打消了出书以及K12教育的念想 趣讲大白话&#xff1a;娘生九子&#xff0c;各有不同 【趣讲信息科技171期】 ****************************…

Vs+Qt+C++电梯调度控制系统

程序示例精选 VsQtC电梯调度控制系统 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<VsQtC电梯调度控制系统>>编写代码&#xff0c;代码整洁&#xff0c;规则&#xff0c;易读。…

PT100温度采集

1、信号采集的基本原理 PT100是将温度信号转换为电阻输出&#xff0c;其电阻值变化范围为0~200Ω。AD转换器只能对电压进行转换&#xff0c;无法采集直接采集温度&#xff0c;因此&#xff0c;需要一个1mA恒电流源给PT100供电&#xff0c;将电阻变化转换为电压变化。使用恒流源…

linux 安装 maven 3.8 版本

文章目录 1&#xff1a;maven 仓库官网 2、下载安装包 3、使用&#xff1a;Xftp 上传到你想放的目录 4、解压文件 ​编辑 5、配置环境变量 ​编辑 6、刷新 /etc/profile 文件 7、查看maven 版本 1&#xff1a;maven 仓库官网 Maven – Download Apache Mavenhttps://mave…

【C++】模板的一点简单介绍

模板 前言泛型编程函数模板概念格式函数模板的原理函数模板的实例化 类模板类模板的定义格式类模板的实例化 前言 这篇博客讲的是模板的一些基本知识&#xff0c;并没有那么深入&#xff0c;但是如果你是为了过期末考试而搜的这篇博客&#xff0c;我觉得下面讲的是够了的。 之…

简单分享线程池的设计

温故而知新&#xff0c;可以为师矣。 线程池是什么 线程池&#xff08;Thread Pool&#xff09;是一种基于池化思想管理线程的工具&#xff0c;经常出现在多线程服务器中&#xff0c;如MySQL。 池化思想&#xff0c;就是为了提高对资源的利用率&#xff0c;减少对资源的管理&a…

MySQL---空间索引、验证索引、索引特点、索引原理

1. 空间索引 MySQL在5.7之后的版本支持了空间索引&#xff0c;而且支持OpenGIS几何数据模型 空间索引是对空间数据类型的字段建立的索引&#xff0c;MYSQL中的空间数据类型有4种&#xff0c;分别是&#xff1a; 类型 含义 说明 Geometry 空间数据 任何一种空间类型 Poi…

HCIA-VRP系统

目录 一&#xff0c;什么是VRP VRP提供的功能&#xff1a; VRP文件系统&#xff1a; VRP存储设备&#xff1a; 设备初始化过程&#xff1a; 设备管理方式&#xff1a; 1&#xff0c;Web界面&#xff1a;可视化操作&#xff0c;通过http和https登录&#xff08;192.168.1.…