【机器学习实战】七、梯度下降

news2024/12/28 19:33:08

梯度下降

一、线性回归

线性回归算法推导过程可以基于最小二乘法直接求解,但这并不是机器学习的思想,由此引入了梯度下降方法。本文讲解其中每一步流程与实验对比分析。

1.初始化
import numpy as np
import os
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')
np.random.seed(42)
2.回归方程

在这里插入图片描述

import numpy as np
X = 2*np.random.rand(100,1)
y = 4+ 3*X +np.random.randn(100,1)
plt.plot(X,y,'b.')
plt.xlabel('X_1')
plt.ylabel('y')
plt.axis([0,2,0,15])
plt.show()

在这里插入图片描述

X_b = np.c_[np.ones((100,1)),X]
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
print(theta_best) 
# 输出 :
array([[4.21509616],
       [2.77011339]])
X_new = np.array([[0],[2]])
X_new_b = np.c_[np.ones((2,1)),X_new]
y_predict = X_new_b.dot(theta_best)
print(y_predict)
# 输出:
array([[4.21509616],
       [9.75532293]])
plt.plot(X_new,y_predict,'r--')
plt.plot(X,y,'b.')
plt.axis([0,2,0,15])
plt.show()

在这里插入图片描述

二、调用sklearn API

sklearnAPI官网: https://scikit-learn.org/stable/modules/classes.html

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X,y)
print (lin_reg.coef_)
print (lin_reg.intercept_)
# 
[[2.77011339]]
[4.21509616]

三、梯度下降

在这里插入图片描述
当步长较小时,训练次数较多;
在这里插入图片描述
当步长较大时,波动大;
在这里插入图片描述
学习率应当尽可能小,随着迭代的进行应当越来越小。
在这里插入图片描述
在这里插入图片描述

1.批量梯度下降
eta = 0.1 #学习率
n_iterations = 1000 # 迭代次数
m = 100
theta = np.random.randn(2,1) # 随机初始化参数theta
for iteration in range(n_iterations):
    gradients = 2/m* X_b.T.dot(X_b.dot(theta)-y)
    theta = theta - eta*gradients
theta
# 
array([[4.21509616],
       [2.77011339]])
X_new_b.dot(theta)
#
array([[4.21509616],
       [9.75532293]])
theta_path_bgd = []
def plot_gradient_descent(theta,eta,theta_path = None):
    m = len(X_b)
    plt.plot(X,y,'b.')
    n_iterations = 1000
    for iteration in range(n_iterations):
        y_predict = X_new_b.dot(theta)
        plt.plot(X_new,y_predict,'b-')
        gradients = 2/m* X_b.T.dot(X_b.dot(theta)-y)
        theta = theta - eta*gradients
        if theta_path is not None:
            theta_path.append(theta)
    plt.xlabel('X_1')
    plt.axis([0,2,0,15])
    plt.title('eta = {}'.format(eta))
theta = np.random.randn(2,1)

plt.figure(figsize=(10,4))
plt.subplot(131)
plot_gradient_descent(theta,eta = 0.02)
plt.subplot(132)
plot_gradient_descent(theta,eta = 0.1,theta_path=theta_path_bgd)
plt.subplot(133)
plot_gradient_descent(theta,eta = 0.5)
plt.show()

在这里插入图片描述

2.随机梯度下降

在这里插入图片描述

theta_path_sgd=[]
m = len(X_b)
np.random.seed(42)
n_epochs = 50
t0 = 5
t1 = 50

def learning_schedule(t):
    return t0/(t1+t)
theta = np.random.randn(2,1)

for epoch in range(n_epochs):
    for i in range(m):
        if epoch < 10 and i<10:
            y_predict = X_new_b.dot(theta)
            plt.plot(X_new,y_predict,'r-')
        random_index = np.random.randint(m)
        xi = X_b[random_index:random_index+1]
        yi = y[random_index:random_index+1]
        gradients = 2* xi.T.dot(xi.dot(theta)-yi)
        eta = learning_schedule(epoch*m+i)
        theta = theta-eta*gradients
        theta_path_sgd.append(theta)
        
plt.plot(X,y,'b.')
plt.axis([0,2,0,15])   
plt.show()

在这里插入图片描述

3.MiniBatch梯度下降
theta_path_mgd=[]
n_epochs = 50
minibatch = 16
theta = np.random.randn(2,1)
t0, t1 = 200, 1000
def learning_schedule(t):
    return t0 / (t + t1)
np.random.seed(42)
t = 0
for epoch in range(n_epochs):
    shuffled_indices = np.random.permutation(m)
    X_b_shuffled = X_b[shuffled_indices]
    y_shuffled = y[shuffled_indices]
    for i in range(0,m,minibatch):
        t+=1
        xi = X_b_shuffled[i:i+minibatch]
        yi = y_shuffled[i:i+minibatch]
        gradients = 2/minibatch* xi.T.dot(xi.dot(theta)-yi)
        eta = learning_schedule(t)
        theta = theta-eta*gradients
        theta_path_mgd.append(theta)
theta 
# 
array([[4.25490684],
       [2.80388785]])

四、3种策略的对比实验

theta_path_bgd = np.array(theta_path_bgd)
theta_path_sgd = np.array(theta_path_sgd)
theta_path_mgd = np.array(theta_path_mgd)
plt.figure(figsize=(12,6))
plt.plot(theta_path_sgd[:,0],theta_path_sgd[:,1],'r-s',linewidth=1,label='SGD')
plt.plot(theta_path_mgd[:,0],theta_path_mgd[:,1],'g-+',linewidth=2,label='MINIGD')
plt.plot(theta_path_bgd[:,0],theta_path_bgd[:,1],'b-o',linewidth=3,label='BGD')
plt.legend(loc='upper left')
plt.axis([3.5,4.5,2.0,4.0])
plt.show()

在这里插入图片描述
实际当中用minibatch比较多,一般情况下选择batch数量应当越大越好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/343939.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言(结构和指针)

目录 1.声明结构指针 2.用指针访问成员 3.传递结构成员 4.传递结构的地址 5.传递结构 6.机构的其他特性 7.结构中的字符数组和字符指针 关于为什么要使用指向结构的指针。 第一&#xff0c;就像指向数组的指针比数组本身更容易操作一样&#xff0c;指向结构的指针通常比…

5年自动化测试,终于进字节了,年薪30w其实也并非触不可及

我的职业生涯开始和大多数测试人一样&#xff0c;开始接触都是纯功能界面测试&#xff0c;第一份测试工作就是在电商公司做功能测试&#xff0c;工作忙忙碌碌&#xff0c;每天在各种业务需求学习和点点中度过&#xff0c;过了好几年发现自己还只是一个功能测试工程师&#xff0…

第十二章 Ambari二次开发之集成Alluxio

1、Alluxio高可用部署 生产环境&#xff1a;使用具有高可用性的模式来运行Alluxio masters。 1.1、Alluxio架构 ​ Alluxio可以被分为三个部分&#xff1a;**masters、workers以及clients。**一个典型的设置由一个主服务器、多个备用服务器和多个worker组成。客户端用于通过S…

机器学习实战--梯度下降法进行波士顿房价预测

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 今天来学习一下如何使用机器学习梯度下降法进行波士顿房价预测&#xff0c;这是简单的一个demo&#xff0c;主要展示的是一些小小的思路~ 本文目录&#xff1a;一、波士顿房价预测1.全部的数据可视化2.地理数据可视化3.房…

基于”PLUS模型+“生态系统服务多情景模拟预测实践

工业革命以来&#xff0c;社会生产力迅速提高&#xff0c;人类活动频繁&#xff0c;此外人口与日俱增对土地的需求与改造更加强烈&#xff0c;人-地关系日益紧张。此外&#xff0c;土地资源的不合理开发利用更是造成了水土流失、植被退化、水资源短缺、区域气候变化、生物多样性…

根据手机号显示其运营商信息phone.find

【小白从小学Python、C、Java】【计算机等级考试500强双证书】【Python-数据分析】根据手机号显示其运营商信息phone.find选择题以下关于python代码表述错误的一项是?from phone import PhonephonePhone()print(【执行】phone.find())resultphone.find("13366667777"…

21.操作符优先级和结合性列表,复杂表达式求值顺序

目录一、复杂表达式求值顺序1.操作符的优先级2.操作符的结合性3.操作符是否控制执行的顺序二、求值顺序三、操作符优先级和结合性列表一、复杂表达式求值顺序 复杂表达式的求值顺序由三个因素决定&#xff1a; 1.操作符的优先级 2.操作符的结合性 3.操作符是否控制执行的顺序 1…

【代码随想录训练营】【Day12休息】【Day13】第五章|栈与队列|239. 滑动窗口最大值|347.前 K 个高频元素|总结

239.滑动窗口最大值 题目详细&#xff1a;LeetCode.239 看到滑动窗口&#xff0c;我立马想起了双指针&#xff0c;利用双指针可以非常清晰地理解解题思路&#xff1a; 定义一个变量 max_i 用于记录窗口中的最大值的索引每次窗口滑动后 如果出去的值是最大值&#xff0c;那么…

ChatGPT实火,这小东西牛在哪?

ChatGPT&#xff0c;真的火了啊&#xff01; 相信许多朋友都听说过 ChatGPT铺天盖地的赞美&#xff0c;但并不清楚它是个啥。 体制内让ChatGPT写材料&#xff0c;广告行业让ChatGPT写策划案&#xff0c;媒体让ChatGPT写新闻稿&#xff0c;程序员让ChatGPT写代码甚至还带修BUG服…

三、常用样式讲解一

文章目录一、企业站点样式实战1.1 版心1.2 reset.css1.3 index.css&#xff08;首页的样式&#xff09;1.4 溢出1.5 元素类型1.6 元素类型的转换1.7 行内块元素的特殊情况&#xff1a;img标签的特殊性一、企业站点样式实战 1.1 版心 1.2 reset.css /* reset.css用作清除一些常…

行人检测(人体检测)2:YOLOv5实现人体检测(含人体检测数据集和训练代码)

行人检测(人体检测)2&#xff1a;YOLOv5实现人体检测(含人体检测数据集和训练代码) 目录 行人检测(人体检测)2&#xff1a;YOLOv5实现人体检测(含人体检测数据集和训练代码) 1. 前言 2. 人体检测数据集说明 &#xff08;1&#xff09;人体检测数据集 &#xff08;2&#…

什么是互联网舆情监测分析系统,TOOM舆情监测云服务有哪些内容?

舆情监测应用范围广泛&#xff0c;可以帮助企业了解品牌形象、产品口碑、市场竞争、消费者需求等信息&#xff0c;政府了解民意状况、政策反响、社会热点等信息&#xff0c;个人了解社会趋势、舆论氛围、公共事件等信息。同时&#xff0c;舆情监测分析也可以帮助相关决策者及时…

男生vs女生,谁更加适合做软件测试?

前言 随着互联网的飞速发展&#xff0c;软件测试行业同步兴盛起来&#xff0c;逐渐出现了人才的短缺&#xff0c;致使行业人员工资一涨再涨。 所以&#xff0c;越来越多的人也开始意识到软件测试行业的”高薪“属性&#xff0c;转身投入到相关的工作中来。 但是&#xff0c;…

【Spring Cloud】如何把Feign默认的HTTP客户端URLConnection更换成支持连接池的Apache HttpClient或OKHttp

本期目录前言1. Feign底层的客户端实现2. Feign性能优化思路3. 更换底层客户端1&#xff09;引入依赖坐标2&#xff09;配置连接池前言 本次示例代码的文件结构如下图所示。 1. Feign底层的客户端实现 Feign 发送 HTTP 请求时&#xff0c;底层会使用到别的客户端。下面列出…

微服务网关(四)tcp代理模块

微服务网关&#xff08;四&#xff09;tcp代理模块 tcp代理服务器的代理实现&#xff1a; 请求流程&#xff1a; 代理的启停方法 //并发执行 go func() {tcp_proxy_router.TcpServerRun() }()tcp_proxy_router.TcpServerStop()tcp_server 一次完整流程 tcp_server.go 首先…

JVM的垃圾回收机制

复制算法、Eden区和Survivor区 首先我们就来探索一下对于JVM堆内存中的新生代区域&#xff0c;是怎么进行垃圾回收的。 实际上JVM是把新生代分为三块区域的&#xff1a;1个Eden区&#xff0c;2个Survivor区。 其中Eden区占用80%的内存空间&#xff0c;每块Survivor各占用10%的内…

使用yolov5训练数据集笔记

准备工作 1. 安装labelimg labelimg:主要用于目标检测的目标框绘制&#xff0c;得到关于我们训练的边框位置、类别等数据 pip install labelimg2. 下载yolov5源码 我使用的是v7.0版本&#xff0c;直接下载即可&#xff0c;下载后解压出来 2.1 安装yolov5运行依赖包 进入…

SurfaceFlinger详解

SurfaceFlinger的定义 大多数应用在屏幕上一次显示三个层&#xff1a;屏幕顶部的状态栏、底部或侧面的导航栏以及应用界面。有些应用会拥有更多或更少的层&#xff08;例如&#xff0c;默认主屏幕应用有一个单独的壁纸层&#xff0c;而全屏游戏可能会隐藏状态栏&#xff09;。…

棱形打印--进阶2(Java)

棱形打印 问题 * *** ***** ******* ********* ******* ***** *** * * * …

centos上搭建nginx视频点播服务器(nginx+vod+lua http发送鉴权消息)

需求背景&#xff1a;想着搭建一个视频点播服务器&#xff0c;最后选择了nginxvod的方案&#xff0c;用lua脚本写拉流鉴权&#xff0c;但是环境搭建过程中又发现nginxvodlua的环境并不是很容易搭建&#xff0c;是nginxlua的环境&#xff0c;手动搭建比较麻烦&#xff0c;但还是…