24/8/5算法笔记 BGD,SGD,MGD梯度下降

news2024/11/16 10:17:43

今日对比不同梯度下降的代码

1.BGD大批量梯度下降(一元一次)

首先导入库

import numpy as np

import matplotlib.pyplot as plt

随机生成线性回归函数

X=np.random.rand(100,1)

w,b=np.random.randint(1,10,size=2)

#增加噪声,更像真实数据
#numoy广播机制
y= w*X+b+np.random.randn(100,1)
plt.scatter(X,y)

使用了 np.concatenate 来将一个全为1的列向量添加到 X 的左边,从而在数据集中包含截距项

# 将b作为偏置项,截距,对应的系数,理解为1
X=np.concatenate([X,np.full(shape = (100,1),fill_value=1)],axis=1)
X[:10]

大批量梯度下降函数

#循环次数
epoches = 10000

#学习率
eta =0.01
t0 = 5
t1 = 1000
#优化:学习率变化
#t 梯度下降的次数,逆时衰减,随着梯度下降次数增加学习率,变小
def learning_rate_shedule(t):
    return t0/(t + t1)
#要求解的系数
theta = np.random.randn(2,1)


#梯度下降的次数
t=0
for i in range(epoches):
    #批量梯度下降
    g = X.T.dot(X.dot(theta)-y)#根据公式计算梯度
    eta = learning_rate_shedule(t)#随着梯度下降增加,学习率下降
    theta = theta - eta*g
print('真实的斜率,截距是:',w,b)
print('BGD求解的斜率,截距是:',theta)#因为增加了噪声所以有差距

大批量(多元一次)

import numpy as np
import matplotlib.pyplot as plt
X=np.random.rand(100,8)

w=np.random.randint(1,10,size=(8,1))
b=np.random.randint(1,10,size=1)
#增加噪声,更像真实数据
#numoy广播机制
y= X.dot(w)+b+np.random.randn(100,1)

#截距当成偏执项
X =np.concatenate ([X,np.full(shape=(100,1),fill_value=1)],axis=1)

#循环次数
epoches = 10000

#学习率
eta =0.01
t0 = 5
t1 = 1000
#优化:学习率变化
#t 梯度下降的次数,逆时衰减,随着梯度下降次数增加学习率,变小
def learning_rate_shedule(t):
    return t0/(t + t1)
#要求解的系数
theta = np.random.randn(9,1)


#梯度下降的次数
t=0
for i in range(epoches):
    #批量梯度下降
    g = X.T.dot(X.dot(theta)-y)#根据公式计算梯度
    eta = learning_rate_shedule(t)#随着梯度下降增加,学习率下降
    theta = theta - eta*g
print('真实的斜率,截距是:',w,b)
print('BGD求解的斜率,截距是:',theta)#因为增加了噪声所以有差距


2.SGD随机梯度下降

导入库

import numpy as np
import matplotlib.pyplot as plt
X=np.random.rand(100,1)

w,b=np.random.randint(1,10,size=2)

#增加噪声,更像真实数据
#numoy广播机制
y= w * X + b + np.random.randn(100,1)
plt.scatter(X,y)

#编置,x也需要增加#full_like(用于创建一个与给定数组 a 形状和类型相同,但填充了指定值的新数组。)
X = np.concatenate([X,np.full_like(X,fill_value=1)],axis = 1)

#循环次数
epoches = 100

t0 = 5
t1 = 1000
#优化:学习率变化
#t 梯度下降的次数,逆时衰减,随着梯度下降次数增加学习率,变小
def learning_rate_shedule(t):
    return t0/(t + t1)
theta = np.random.randn(2,1)

count=0
for t in range(epoches):
    index = np.arange(100)
    np.random.shuffle(index)#洗牌,打乱
    
    #numpy花式索引
    X=X[index]
    y=y[index]
    for i in range(100):
        #随机抽样!!!抽了一个样本
        X_i=X[i]
        y_i=y[i]
    
        #根据这个样本,逆行计算梯度
        #单个样本依然是矩阵
        g = X_i.T.dot(X_i-theta* y_i)
        eta = learning_rate_shedule(t+count)
        count +=1
        theta -= eta*g
print('真实的斜率,截距是:',w,b)
print('SGD求解的斜率,截距是:',theta)

3.MGD小批量梯度下降

import numpy as np
import matplotlib.pyplot as plt

X=np.random.rand(100,1)

w,b=np.random.randint(1,10,size=2)

y= w * X + b + np.random.randn(100,1)

X = np.c_[X,np.ones((100,1))]

t0,t1=5,500
def learning_rate_shedule(t):
    return t0/(t + t1)

epoches = 100
n=100
batch_size = 16
num_batches = int(n/batch_size)


#5初始化 w0..wn,标准正态分布创建
o =np.random.randn(2,1)

#6 多次for循环实现梯度下降,最终结果收敛
for epoch in range(epoches):
    #在双层for循环之间,每个轮次开始分批次迭代之前打乱数据索引顺序
    index = np.arange(n)
    np.random.shuffle(index)#打乱
    X = X[index]
    y = y[index]
    for i in range(num_batches):
        #切片
        X_batch = X[i * batch_size:(i+1)*batch_size]
        y_batch = y[i * batch_size:(i+1)*batch_size]
        
        g = X_batch.T.dot(X_batch.dot(o) - y_batch)
        
        learning_rate = learning_rate_shedule(epoch*n+i)
        o = o-learning_rate*g
print('真实斜率和截距',w,b)
print('梯度下降计算斜率和截距',o)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1981965.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql的安装与基本操作

1、centos7 中安装 mysql 8.x(1)下载安装包 wget https://downloads.mysql.com/archives/get/p/23/file/mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar(2)解压 tar -xf mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar(3&…

PXE实验-使用kickstart批量自动部署操作系统

实验准备:rhel7.9具备图形界面的虚拟机,虚拟机网络配置可用,VMware 中NAT的DHCP功能关闭,虚拟机中yum源已配置好 1.在虚拟机中安装kickstart并且启动图形制作工具 yum install system-config-kickstart.noarch -y system-config…

【第13章】Spring Cloud之Gateway全局异常处理

文章目录 前言一、异常处理1. 响应实体类2. 异常处理类 二、单元测试1. 无可用路由2. 服务不可用 总结 前言 网关作为我们对外服务的入口起着至关重要的作用,我们必须保证网关服务的稳定性,下面来为网关服务增加异常处理机制。 一、异常处理 1. 响应实…

动态规划.

目录 (一)递归到动规的一般转化方法 (二)动规解题的一般思路 1. 将原问题分解为子问题 2. 确定状态 3. 确定一些初始状态(边界状态)的值 4. 确定状态转移方程 (三)能用动规解…

小程序 发布流程

1: 点击HbuilderX 菜单栏上的 发行> 小程序-微信(适用于uni-app) 2: 第二步: 需要再弹出框中填写发布系小程序的名称和AppId 之后, 点击发行按钮。 3:在Hbuilder 的控制台中 查看小程序发布编译的进度。…

VMware17下载与安装

1.下载 通过百度网盘分享的文件:VMware17 链接:https://pan.baidu.com/s/1gCine3d3Rp_l3NYAu5-ojg 提取码:ek25 --来自百度网盘超级会员V3的分享 2.安装

k8s(六)---pod

六、pod(k8s中最小的调度单元) pod中可以有一个或多个容器 1、官网 2、简介 Pod是k8s中最小的调度单元、Pod具有命名空间隔离性 3、如何创建一个Pod资源(主要两种方式) 1)kubctl run ①kubectl run nginx–imagereg…

k8s(七)---标签

一、标签(适用于资源定位) label是一对key和value,创建标签后,方便对资源进行分组管理。 1.帮助 kubectl label --help 2.打标签 pod 针对于pod打标签 key是env,value是test kubectl label po nginx envtest 给pod打标签 3.查看 k…

Qcustomplot绘制实时动态曲线??

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

uviewPlus 组件库的使用

文章目录 1、 1、 全局引入样式文件 (该语句是文档中提及但是不存在的语句)

mysql的安装配置与基础用户使用

第五周 周一 早 mysql安装配置 1.官网下载或者wget [rootmysql ~]# ls anaconda-ks.cfg initserver.sh mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar mysql-community-client-8.0.33-1.el7.x86_64.rpm mysql-community-client-plugins-8.0.33-1.el7.x86_64.rpm mysql-c…

Dockerfile 容器镜像制作 私有仓库

Dockerfile 概述 制作镜像 FROM CMD # ENTRYPOINT 与 CMD 执行方式为 ${ENTRYPOINT} ${-${CMD}} apache 镜像 nginx 镜像 php-fpm 镜像 docker 私有仓库

单位工作邮箱如何实现快速开通

单位工作邮箱如何实现快速开通?单位工作邮箱快速开通需分析需求、选合适服务商、备材料、注册验证配置MX记录、创账户。开通前需测试邮件收发、功能及安全,确保稳定运行。本文将详细介绍单位工作邮箱的前期准备以及快速开通的流程。 一、需求分析与规划…

有了谷歌账号在登录游戏或者新APP、新设备时,要求在手机上点击通知和数字,怎么办?

有的朋友可能遇到过,自己注册或购买了谷歌账号以后,在自己的手机上可以正常登录,也完成了相关的设置,看起来一切都很完美,可以愉快地玩耍了。 但是,随后要登录一个游戏的时候(或者登录一个新的…

[Web安全架构] HTTP协议

文章目录 前言1. HTTP1 . 1 协议特点1 . 2 URL1 . 3 Request请求报文1 . 3 .1 请求行1 . 3 .2 请求头1 . 3 .3 请求正文1 . 3 .4 常见传参方式 1 . 4 Response响应报文1 . 4 .1 响应行1 . 4 .2 响应头1 . 4 .3 响应正文 2. Web会话2 .1 Cookie2 .2 Session2 .3 固定会话攻击 前…

TypeScript循环

循环 循环 一直重复的做某一件事 循环需要的必须条件:1.开始条件 2.结束条件3.变量的更新 while循环允许程序在满足特定条件时重复执行一段代码块,直到条件不再满足为止 结构:while(条件表达式){ //需要重复执行的代码块 } let a:numb…

【ESP01开发实例】-ESP-01网络天气数据获取

ESP-01网络天气数据获取 文章目录 ESP-01网络天气数据获取1、硬件准备与接线2、天气数据获取准备3、代码实现在本文中,将展示如何使用 ESP8266 (ESP-01) Wi-Fi 模块构建一个简单的互联网气象站。 ESP8266 可以访问互联网(网页)并从为全球许多城市提供免费天气信息的网站获取…

监控员工电脑的软件有哪些?四款监控员工电脑的软件分享!

古之治事,必明察秋毫,以驭群才。今之世,科技日新,监控之术亦随之而变。有软件四款,专司员工电脑之监,以助上司洞察细微,安内攘外。今略陈其要,尤以“安企神”为详。 一、安企神软件 …

Linux笔记 --- 传统链表

目录 链表 单向链表 单向循环链表 双向链表 设计表 初始化 在auchor后插入节点, 在auchor前插入节点 删除节点 传统链表 通过使用链表我们可以将一个数组中的数据分开到不同位置存放并使用指针指向他们,使之逻辑相连,解决了顺序存储所需要…

软件更新中的风险识别与质量保证机制分析

​ ​ 您好,我是程序员小羊! “微软蓝屏”事件暴露了网络安全哪些问题? 近日,一次由微软视窗系统软件更新引发的全球性“微软蓝屏”事件,不仅成为科技领域的热点新闻,更是一次对全球IT基础设施韧性与安全性…