机器学习模型——回归模型

news2024/11/17 23:32:35

文章目录

  • 监督学习——回归模型
  • 线性回归模型
    • 最小二乘法
    • 求解线性回归
    • 代码实现
        • 引入依赖:
        • 导入数据:
        • 定义损失函数:
        • 定义核心算法拟合函数:
        • 测试:
        • 画出拟合曲线:
  • 多元线性回归
    • 梯度下降求线性回归
    • 梯度下降和最小二乘法
    • 代码实现
        • 定义模型的超参数:
        • 定义核心梯度下降算法函数:
        • 测试:
        • 画出拟合曲线:
    • 调用sklearn库代码实现
        • 调用库:
        • 获取数值:
        • 画出拟合曲线:

这里来学习一下机器学习的一些模型,包括:

  • 监督学习:
    • 回归模型:
      • 线性回归
    • 分类模型:
      • k近邻(kNN)
      • 决策树
      • 逻辑斯谛回归
  • 无监督学习:
    • 聚类:
      • k均值(k-means)
    • 降维

监督学习——回归模型

线性回归模型

  • 线性回归(linear regression)是一种线性模型,它假设输入变量x和单个输出变量y之间存在线性关系。
  • 具体来说,利用线性回归模型,可以从一组输入变量x的线性组合中,计算输出变量y

在这里插入图片描述

在这里插入图片描述

  • 线性回归模型:

在这里插入图片描述

定义

在这里插入图片描述

最小二乘法

在这里插入图片描述

在这里插入图片描述

求解线性回归

在这里插入图片描述

代码实现

引入依赖:

import numpy as np
import matplotlib.pyplot as plt

导入数据:

points = np.genfromtxt('data.csv', delimiter=',')

# points[0,0] 第一行第一列的元素

# 提取points中的两列数据,分别作为x,y
x = points[:, 0] # 所有行的第一列元素
y = points[:, 1] # 所有行的第二列元素

# 用plt画出散点图
plt.scatter(x, y)
plt.show()

在这里插入图片描述

其中数据data.csv:

32.502345269453031,31.70700584656992
53.426804033275019,68.77759598163891
61.530358025636438,62.562382297945803
47.475639634786098,71.546632233567777
59.813207869512318,87.230925133687393
55.142188413943821,78.211518270799232
52.211796692214001,79.64197304980874
39.299566694317065,59.171489321869508
48.10504169176825,75.331242297063056
52.550014442733818,71.300879886850353
45.419730144973755,55.165677145959123
54.351634881228918,82.478846757497919
44.164049496773352,62.008923245725825
58.16847071685779,75.392870425994957
56.727208057096611,81.43619215887864
48.955888566093719,60.723602440673965
44.687196231480904,82.892503731453715
60.297326851333466,97.379896862166078
45.618643772955828,48.847153317355072
38.816817537445637,56.877213186268506
66.189816606752601,83.878564664602763
65.41605174513407,118.59121730252249
47.48120860786787,57.251819462268969
41.57564261748702,51.391744079832307
51.84518690563943,75.380651665312357
59.370822011089523,74.765564032151374
57.31000343834809,95.455052922574737
63.615561251453308,95.229366017555307
46.737619407976972,79.052406169565586
50.556760148547767,83.432071421323712
52.223996085553047,63.358790317497878
35.567830047746632,41.412885303700563
42.436476944055642,76.617341280074044
58.16454011019286,96.769566426108199
57.504447615341789,74.084130116602523
45.440530725319981,66.588144414228594
61.89622268029126,77.768482417793024
33.093831736163963,50.719588912312084
36.436009511386871,62.124570818071781
37.675654860850742,60.810246649902211
44.555608383275356,52.682983366387781
43.318282631865721,58.569824717692867
50.073145632289034,82.905981485070512
43.870612645218372,61.424709804339123
62.997480747553091,115.24415280079529
32.669043763467187,45.570588823376085
40.166899008703702,54.084054796223612
53.575077531673656,87.994452758110413
33.864214971778239,52.725494375900425
64.707138666121296,93.576118692658241
38.119824026822805,80.166275447370964
44.502538064645101,65.101711570560326
40.599538384552318,65.562301260400375
41.720676356341293,65.280886920822823
51.088634678336796,73.434641546324301
55.078095904923202,71.13972785861894
41.377726534895203,79.102829683549857
62.494697427269791,86.520538440347153
49.203887540826003,84.742697807826218
41.102685187349664,59.358850248624933
41.182016105169822,61.684037524833627
50.186389494880601,69.847604158249183
52.378446219236217,86.098291205774103
50.135485486286122,59.108839267699643
33.644706006191782,69.89968164362763
39.557901222906828,44.862490711164398
56.130388816875467,85.498067778840223
57.362052133238237,95.536686846467219
60.269214393997906,70.251934419771587
35.678093889410732,52.721734964774988
31.588116998132829,50.392670135079896
53.66093226167304,63.642398775657753
46.682228649471917,72.247251068662365
43.107820219102464,57.812512976181402
70.34607561504933,104.25710158543822
44.492855880854073,86.642020318822006
57.50453330326841,91.486778000110135
36.930076609191808,55.231660886212836
55.805733357942742,79.550436678507609
38.954769073377065,44.847124242467601
56.901214702247074,80.207523139682763
56.868900661384046,83.14274979204346
34.33312470421609,55.723489260543914
59.04974121466681,77.634182511677864
57.788223993230673,99.051414841748269
54.282328705967409,79.120646274680027
51.088719898979143,69.588897851118475
50.282836348230731,69.510503311494389
44.211741752090113,73.687564318317285
38.005488008060688,61.366904537240131
32.940479942618296,67.170655768995118
53.691639571070056,85.668203145001542
68.76573426962166,114.85387123391394
46.230966498310252,90.123572069967423
68.319360818255362,97.919821035242848
50.030174340312143,81.536990783015028
49.239765342753763,72.111832469615663
50.039575939875988,85.232007342325673
48.149858891028863,66.224957888054632
25.128484647772304,53.454394214850524

定义损失函数:

# 损失函数是系数的函数,另外还要传入数据的x,y
def compute_cost(w, b, points):
    total_cost = 0
    M = len(points)
    
    # 逐点计算平方损失误差,然后求平均数
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        total_cost += ( y - w * x - b ) ** 2 # **2 代表平方
    
    return total_cost / M

定义核心算法拟合函数:

# 先定义一个求均值的函数
def average(data):
    sum = 0
    num = len(data)
    for i in range(num):
        sum += data[i]
    return sum/num

# 定义核心拟合函数
def fit(points):
    M = len(points)
    x_bar = average(points[:, 0])
    
    sum_yx = 0
    sum_x2 = 0
    sum_delta = 0
    
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        sum_yx += y * ( x - x_bar )
        sum_x2 += x ** 2
    # 根据公式计算w
    w = sum_yx / ( sum_x2 - M * (x_bar**2) )
    
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        sum_delta += ( y - w * x )
    b = sum_delta / M
    
    return w, b

测试:

w, b = fit(points)

print("w is: ", w) # 斜率
print("b is: ", b) 

cost = compute_cost(w, b, points)

print("cost is: ", cost)
w is:  1.3224310227553846
b is:  7.991020982269173
cost is:  110.25738346621313

画出拟合曲线:

plt.scatter(x, y)
# 针对每一个x,计算出预测的y值
pred_y = w * x + b

plt.plot(x, pred_y, c='r')
plt.show()

在这里插入图片描述

多元线性回归

在这里插入图片描述

梯度下降求线性回归

在这里插入图片描述

在这里插入图片描述

梯度下降和最小二乘法

  • 相同点
    • 本质和目标相同:两种方法都是经典的学习算法,在给定已知数据的前提下利用求导算出一个模型(函数),使得损失函数最小,然后对给定的新数据进行估算预测
  • 不同点
    • 损失函数:梯度下降可以选取其他损失函数,而最小二乘一定是平方损失函数
    • 实现方法:最小二乘法是直接求导出全局最小;而梯度下降是一种迭代法。
    • 效果:最小二乘找到的一定是全局最小,但计算繁琐,且复杂情况下未必有解;梯度下降迭代计算简单,但找到的一般是局部最小,只有目标函数式凸函数时才是全局最小;到最小点附近时收敛速度会变慢,且对初始点的选择极为敏感

代码实现

引入依赖、导入数据、定义损失函数同上。

定义模型的超参数:

alpha = 0.0001 # 步长
initial_w = 0
initial_b = 0
num_iter = 10 # 迭代次数

定义核心梯度下降算法函数:

def grad_desc(points, initial_w, initial_b, alpha, num_iter):
    w = initial_w
    b = initial_b
    # 定义一个list保存所有的损失函数值,用来显示下降的过程
    cost_list = []
    
    for i in range(num_iter):
        cost_list.append( compute_cost(w, b, points) )
        w, b = step_grad_desc( w, b, alpha, points )
    
    return [w, b, cost_list]

# 迭代
def step_grad_desc( current_w, current_b, alpha, points ):
    sum_grad_w = 0
    sum_grad_b = 0
    M = len(points)
    
    # 对每个点,代入公式求和
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        sum_grad_w += ( current_w * x + current_b - y ) * x
        sum_grad_b += current_w * x + current_b - y
    
    # 用公式求当前梯度
    grad_w = 2/M * sum_grad_w
    grad_b = 2/M * sum_grad_b
    
    # 梯度下降,更新当前的w和b
    updated_w = current_w - alpha * grad_w
    updated_b = current_b - alpha * grad_b
    
    return updated_w, updated_b

测试:

w, b, cost_list = grad_desc( points, initial_w, initial_b, alpha, num_iter )

print("w is: ", w)
print("b is: ", b)

cost = compute_cost(w, b, points)

print("cost is: ", cost)

plt.plot(cost_list)
plt.show()

在这里插入图片描述

画出拟合曲线:

plt.scatter(x, y)
# 针对每一个x,计算出预测的y值
pred_y = w * x + b

plt.plot(x, pred_y, c='r')
plt.show()

在这里插入图片描述

调用sklearn库代码实现

引入依赖、导入数据、定义损失函数同上。

调用库:

from sklearn.linear_model import LinearRegression
lr = LinearRegression()

x_new = x.reshape(-1, 1) # reshape(-1, 1) -1表示行数不限,1表示列数为1,即变为n行1列矩阵
y_new = y.reshape(-1, 1)
lr.fit(x_new, y_new)

获取数值:

# 从训练好的模型中提取系数和截距
w = lr.coef_[0][0] # 系数
b = lr.intercept_[0] # 截距

print("w is: ", w)
print("b is: ", b)

cost = compute_cost(w, b, points)

print("cost is: ", cost)
w is:  1.3224310227553597
b is:  7.991020982270399
cost is:  110.25738346621318

画出拟合曲线:

plt.scatter(x, y)
# 针对每一个x,计算出预测的y值
pred_y = w * x + b

plt.plot(x, pred_y, c='r')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/570276.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能python:Python中%取模操作的介绍

Python中%取模操作的介绍 在Python中,取模操作使用符号“%”表示,它的作用是取两个数相除的余数。例如,10 % 3等于1,因为10除以3的余数为1。这个操作可以用在很多场合,比如判断一个数是奇数还是偶数,或者判…

带你开发一个远程控制项目---->STM32+标准库+阿里云平台+传感器模块+远程显示。

目录 本次实验项目: 下次实验项目: 本次项目视频结果/APP/实物展示 实物展示 APP展示 视频展示 模块选择说明; 温湿度传感器模块介绍 光照传感器介绍 ESP8266-01S模块介绍 本次实验项目: 项目清单平台单片机语言实现温湿度传感器模…

Reinforcement Learning | 强化学习十种应用场景及新手学习入门教程

文章目录 1.在自动驾驶汽车中的应用2.强化学习的行业自动化3.强化学习在贸易和金融中的应用4.NLP(自然语言处理)中的强化学习5.强化学习在医疗保健中的应用6.强化学习在工程中的应用7.新闻推荐中的强化学习8.游戏中的强化学习9.实时出价——强化学习在营…

Redis中的Reactor模型源码探索

文章目录 摘要了解Linux的epoll了解Reactor模型 源码initServerinitListenersaeMain 事件管理器aeProcessEvents读事件 摘要 有时候在面试的时候会被问到Redis为什么那么快?有一点就是客户端请求和应答是基于I/O多路复用(比如linux的epoll)的…

【高级语言程序设计(一)】第 9 章:编译预处理命令

目录 前言 一、宏定义命令 (1)无参宏定义 (2)有参宏定义 ① 带参数的宏定义 ② 带参宏定义与函数的区别 二、文件包含命 (1)文件包含命令的定义 (2)文件包含命令的格式 &…

【Leetcode60天带刷】day02—— 977.有序数组的平方、209.长度最小的子数组、 59.螺旋矩阵II

题目:997.有序数组的平方 Leetcode原题链接:997.有序数组的平方——力扣 思考历程与知识点: 题目的意思很简单,就是把每个数的平方,按从小到大的顺序排个序,再输出出来。 第一想法是先每个数平方一遍&a…

Stream API的使用

使用Stream API对集合中的数据进行操作,就类似使用SQL语句对数据库执行查询 Stream不会存储数据Stream不会改变源对象,而是返回一个持有结果的新StreamStream是延迟执行的,只有在需要结果的时候才执行,即只有执行终止操作&#xf…

离散数学_十章-图 ( 2 ):图的术语和几种特殊的图

📷10.2 图的术语和几种特殊的图 1. 基本术语1.1 邻接(相邻)1.2 邻居1.3 顶点的度1.4 孤立点1.5 悬挂点例题 2. 握手定理3. 握手定理的推论4. 带有有向边的图的术语4.1 邻接4.2 度——出度 和 入度4.3 例题: 5. 定理:入…

PHP 反序列化漏洞

PHP反序列化漏洞在实际测试中出现的频率并不高,主要常出现在CTF中。 PHP序列化概述 PHP序列化函数: serialize:将PHP的数据,数组,对象等序列化为字符串unserialize:将序列化后的字符串反序列化为数据&…

chatgpt赋能python:Python单词库的重要性

Python单词库的重要性 Python是一种高级编程语言,被广泛用于应用程序开发、网络编程、数据科学和人工智能开发等领域。而在Python编程中,单词库(或词典)的重要性不言而喻。单词库就是存放Python程序中经常使用的关键字、方法名、函数名等词汇的地方。本…

SpringBoot --- 实用篇

一、热部署 1.1、概念 什么是热部署?简单说就是你程序改了,现在要重新启动服务器,嫌麻烦?不用重启,服务器会自己悄悄的把更新后的程序给重新加载一遍,这就是热部署。 ​ 热部署的功能是如何实现的呢&…

谷歌浏览器被2345劫持

方法1: 打开控制面板的卸载程序,搜索2345,把那个恶心的“安全组件-2345”卸载掉!! 这个方法比修改 host 以及注册表要好使地多! 参考网址: 【小技巧】修复chrome被2345劫持 方法2: …

Alma Linux 9.2、Rocky Linux 9.2现在是RHEL 9.2的替代品

随着Red Hat Enterprise Linux (RHEL) 9.2的发布,Alma Linux 9.2和Rocky Linux 9.2成为了RHEL 9.2的备选替代品。这两个Linux发行版旨在提供与RHEL兼容的功能和稳定性,以满足那些需要企业级操作系统的用户需求。本文将详细介绍Alma Linux 9.2和Rocky Lin…

nginx反向代理缓存

背景 nginx 一般用来做反向代理和负载均衡,将客户端请求发送到后端的 jetty,并将 jetty 的响应发送给客户端。后端的 jetty 通常不止一个,nginx 根据配置来选择其中一个 jetty,比较常见的选择策略是轮询。示意图如下 启动缓存支…

oracle19c介绍和安装

目录 一、版本 (1)历史 (2)11g和12c管理方式区别 11g 12C (3)各个版本对操作系统要求 二、分类 (1)分为桌面类和服务器类 (2)分为企业版和标准版 三…

基于遗传算法的BP神经网络优化算法(matlab实现)

1 理论基础 1.1 BP神经网络概述 BP网络是一类多层的前馈神经网络。它的名字源于在网络训练的过程中,调整网络的权值的算法是误差的反向传播的学习算法,即为BP学习算法。BP算法是Rumelhart等人在1986年提出来的。由于它的结构简单,可调整的…

个人网站实现微信扫码登录

⭐个人网站实现微信扫码登录 🥈效果图 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kzSrNgiv-1685034480658)(https://img.ggball.top/picGo/动画.gif)] 📗开发背景 为什么想用微信扫码登录呢? 起因是自己开发…

【CH32】| 02——常用外设 | GPIO

系列文章目录 【CH32】| 00——开发环境搭建 【CH32】| 01——新建工程 | 下载 | 运行 |调试 【CH32】| 02——常用外设 | GPIO 失败了也挺可爱,成功了就超帅。 文章目录 前言1. GPIO简介2. IO口的内部结构框图保护二极管上下拉电阻施密特触发器两个MOS管输出寄存器…

chatgpt赋能python:Python加速循环的执行方法详解

Python 加速循环的执行方法详解 Python是一门非常流行的编程语言,它可以在很多领域应用,比如Web开发、数据分析、机器学习等等。然而,Python执行速度较慢,特别是在循环语句中,代码执行效率会大打折扣。在本文中&#…

【基于ROS Melodic环境安装rosserial arduino】

【基于ROS Melodic环境安装rosserial arduino】 1. 简介2. 安装2.1 Ubuntu下的Arduino IDE安装2.2 Ubuntu下rosserial arduino软件安装2.3 安装ros_lib到Arduino IDE开发环境 3. 将ros_lib配置到 Arduino 环境库中4. 使用helloword5. 实验验证6.总结 1. 简介 这个教程展示如何…