Python纯Numpy手撕SGD

news2024/11/15 21:29:56

文章目录

  • 简介
  • 问题建模
  • 数据加载和预处理
    • 数据加载
    • 预处理
    • 分batch
  • 损失函数
  • 训练
  • 运行

简介

本博客用多元线性回归展示如何从零实现一个随机梯度下降SGD, 不使用torch等AI框架

问题建模

给定一个数据集 X ∈ R N × ( D + 1 ) \large X \in \R^{N \times (D+1)} XRN×(D+1)和对应标签向量 Y ∈ R N \large Y \in \R^{N} YRN, 权重为 W ∈ R D + 1 \large W \in \R^{D+1} WRD+1(包含 ω 0 \large \omega_0 ω0), 其中 N \large N N为数据集规模, D \large D D为样本特征维度.

则预测的标签 Y ^ \large \hat{Y} Y^为:
Y ^ = X W \large \hat{Y} = XW Y^=XW
误差函数采用均方差MSE, 即 ℓ \large \ell
ℓ ( W ) = ∣ Y ^ − Y ∣ 2 2 N = ∣ X W − Y ∣ 2 2 N \large \ell(W) = \frac{|\hat{Y}-Y|_2}{2N}=\frac{|XW-Y|_2}{2N} (W)=2NY^Y2=2NXWY2
根据梯度下降理论, 误差函数 ℓ ( W ) \large \ell(W) (W)求导为:
∂ ∂ W ℓ ( W ) = ∂ ∂ W ∣ X W − Y ∣ 2 2 N = ∂ ∂ W ∣ X W − Y ∣ 2 2 N = ∂ ∂ ( W ) ( X W − Y ) T ( X W − Y ) 2 N = ∂ ∂ ( W ) W T X T X W − W T X T Y − Y T X W + Y T Y 2 N = X T ( X W − Y ) N \large \frac{\partial}{\partial W}\ell(W) = \frac{\partial}{\partial W}\frac{|XW-Y|_2}{2N} \\ \large = \frac{\partial}{\partial W}\frac{|XW-Y|^2}{2N} \\ \large = \frac{\partial}{\partial (W)}\frac{(XW-Y)^T(XW-Y)}{2N} \\ \large = \frac{\partial}{\partial (W)}\frac{W^TX^TXW-W^TX^TY-Y^TXW+Y^TY}{2N} \\ \large = \frac{X^T(XW-Y)}{N} W(W)=W2NXWY2=W2NXWY2=(W)2N(XWY)T(XWY)=(W)2NWTXTXWWTXTYYTXW+YTY=NXT(XWY)
此处理想情况下, 可以直接令导数 ∂ ∂ W ℓ ( W ) = 0 \frac{\partial}{\partial W}\ell(W)=0 W(W)=0, 解得 W \large W W的解析解为 W = ( X T X ) − 1 X T Y \large W=(X^TX)^{-1}X^TY W=(XTX)1XTY. 但是实际场景中, 由于噪声或者样本规模太少等问题, X T X \large X^TX XTX不一定是个可逆矩阵, 因此梯度下降是一个普遍的替代方法.

数据加载和预处理

数据加载

采用房价数据集, 从sklearn下载

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split


X, Y = fetch_california_housing(return_X_y=True)
X.shape, Y.shape	# (20640, 8), (20640, )

预处理

需要对样本 X \large X X增加一列全1的数据, 为了方便划分batch, 然后划分成训练集和测试集.

ones = np.ones(shape=(X.shape[0], 1))
X = np.hstack([X, ones])
validate_size = 0.2
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=validate_size, shuffle=True)

分batch

这里写一个函数, 每次返回一个batch的样本, 为节省内存空间, 采用生成器的形式

def get_batch(batchsize: int, X: np.ndarray, Y: np.ndarray):
    assert 0 == X.shape[0]%batchsize, f'{X.shape[0]}%{batchsize} != 0'
    batchnum = X.shape[0]//batchsize
    X_new = X.reshape((batchnum, batchsize, X.shape[1]))
    Y_new = Y.reshape((batchnum, batchsize, ))

    for i in range(batchnum):
        yield X_new[i, :, :], Y_new[i, :]

损失函数

def mse(X: np.ndarray, Y: np.ndarray, W: np.ndarray):
    return 0.5 * np.mean(np.square(X@W-Y))

def diff_mse(X: np.ndarray, Y: np.ndarray, W: np.ndarray):
    return X.T@(X@W-Y) / X.shape[0]

训练

首先定义超参数

lr = 0.001          # 学习率
num_epochs = 1000   # 训练周期
batch_size = 64     # |每个batch包含的样本数
validate_every = 4  # 多少个周期进行一次检验

定义训练函数

def train(num_epochs: int, batch_size: int, validate_every: int, W0: np.ndarray, X_train: np.ndarray, Y_train: np.ndarray, X_test: np.ndarray, Y_test: np.ndarray):
    loop = tqdm(range(num_epochs))
    loss_train = []
    loss_validate = []
    W = W0
    # 遍历epoch
    for epoch in loop:
        loss_train_epoch = 0
        # 遍历batch
        for x_batch, y_batch in get_batch(64, X_train, Y_train):
            loss_batch = mse(X=x_batch, Y=y_batch, W=W)
            loss_train_epoch += loss_batch*x_batch.shape[0]/X_train.shape[0]
            grad = diff_mse(X=x_batch, Y=y_batch, W=W)
            W = W - lr*grad

        loss_train.append(loss_train_epoch)
        loop.set_description(f'Epoch: {epoch}, loss: {loss_train_epoch}')

        if 0 == epoch%validate_every:
            loss_validate_epoch = mse(X=X_test, Y=Y_test, W=W)
            loss_validate.append(loss_validate_epoch)
            print('============Validate=============')
            print(f'Epoch: {epoch}, train loss: {loss_train_epoch}, val loss: {loss_validate_epoch}')
            print('================================')
    plot_loss(np.array(loss_train), np.array(loss_validate), validate_every)

运行

W0 = np.random.random(size=(X.shape[1], ))  # 初始权重
train(num_epochs=num_epochs, batch_size=batch_size, validate_every=validate_every, W0=W0, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test)

结果如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qZGpALoK-1676738656386)(C:\Users\15646\AppData\Roaming\Typora\typora-user-images\image-20230219003541195.png)]

最后画一下损失函数图:

from matplotlib import pyplot as plt
def plot_loss(loss_train: np.ndarray, loss_val: np.ndarray, validate_every: int):
    %matplotlib
    x = np.arange(0, loss_train.shape[0], 1)
    plt.plot(x, loss_train, label='train')
    x = np.arange(0, loss_train.shape[0], validate_every)
    plt.plot(x, loss_val, label='validate')
    plt.legend()
    plt.title('loss')
    plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/356355.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

centos7防火墙工具firewall-cmd使用

centos7防火墙工具firewall-cmd使用防火墙概述centos7防火墙工具firewall-cmd使用介绍firewalld的基本使用服务管理工具相关指令配置firewalld-cmd防火墙概述 防火墙是可以帮助计算机在内部网络和外部网络之间构建一道相对隔绝的保护屏障,从而保护数据信息的一种技…

Vulnhub 渗透练习(七)—— FRISTILEAKS: 1.3

环境搭建 下载链接 virtualbox 打开靶机设置为 host-only,攻击机同样。 具体可点此处 信息收集 开了个 80 端口。 用的是 apache 2.2.15 ,这个版本有个解析漏洞。 目录 根据首页的图片猜测 /fristi/ 目录(不过我没想到 -_-&#x…

由浅入深掌握各种 Python multiprocessing 进程间通信方式

由浅入深掌握各种 Python 多进程间通信方式1、为什么要掌握进程间通信2、进程间各类通信方式简介3、消息机制通信1) 管道 Pipe 通信方式2) 消息队列Queue 通信方式4、同步机制通信(1) 进程间同步锁 – Lock(2) 子进程间协调机制 -- Event5、共享内存方式通信(1) 共享变量(2) 共…

【Python】控制自己的手机摄像头拍照,并自动发送到邮箱

前言 嗨喽,大家好呀~这里是爱看美女的茜茜呐 今天这个案例,就是控制自己的摄像头拍照, 并且把拍下来的照片,通过邮件发到自己的邮箱里。 想完成今天的这个案例,只要记住一个重点:你需要一个摄像头 思路…

Android 7.0 OTA升级(高通)

文章目录1. Full OTA 方式升级介绍1.1 Full OTA 制作第一步:生成 msm89xx-target_files-eng.XXX.zip1.2 Full OTA 制作第二步:Modem 等非 HLOS 加入升级包的方法1.3 Full OTA 制作第三步:生成 update.zip 升级包2. Incremental OTA 方式升级介…

Android 基础知识4-2.6LinearLayout(线性布局)

一、LinearLayout的概述 线性布局(LinearLayout)主要以水平或垂直方式来排列界面中的控件。并将控件排列到一条直线上。在线性布局中,如果水平排列,垂直方向上只能放一个控件,如果垂直排列,水平方向上也只能…

Java基础-xml

1.xml 1.1概述 万维网联盟(W3C) 万维网联盟(W3C)创建于1994年,又称W3C理事会。1994年10月在麻省理工学院计算机科学实验室成立。 建立者: Tim Berners-Lee (蒂姆伯纳斯李)。 是Web技术领域最具权威和影响力的国际中立性技术标准机构。 到目前为止&#…

python基础语法【自用】

✨始发站🚩Python的基础语法,冲冲冲! 🚩注:本篇为python基础语法篇,因博主之前使用java,所以本基础语法篇实为自用丐版! 🌲 你好,世界! 安装环境…

虚拟机快照

1. 快照有什么作用? 通俗理解:快照就是备份。 2. VMware Workstation 和 VMware Fusion 都支持制作快照去使用 一、快照 保存当前虚拟机状态。可以恢复 二、 在VMware Workstation Pro中制作并还原快照 三、在VMware Fusion Pro中制作并还原快照 快照制…

210天从外包踏进华为跳动那一刻,我泪目了

前言 没有绝对的天才,只有持续不断的付出。对于我们每一个平凡人来说,改变命运只能依靠努力幸运,但如果你不够幸运,那就只能拉高努力的占比。 2021年4月,我有幸成为了华为的一名高级测试工程师,正如标题所…

【软件测试】python接口自动化测试编写脚本,资深测试总结方法,你的实用宝典......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 接口测试&#xff0…

美团前端一面手写面试题

实现斐波那契数列 // 递归 function fn (n){if(n0) return 0if(n1) return 1return fn(n-2)fn(n-1) } // 优化 function fibonacci2(n) {const arr [1, 1, 2];const arrLen arr.length;if (n < arrLen) {return arr[n];}for (let i arrLen; i < n; i) {arr.push(arr[…

vulnhub Kioptrix4

总结&#xff1a;sql注入&#xff0c;受限shell绕过&#xff0c;mysql提权 目录 下载地址 漏洞分析 信息收集 sql注入 ssh登录绕过受限shell 提权 下载地址 Kioptrix4_Hyper_v.rar (Size: 210 MB)Download: http://www.kioptrix.com/dlvm/Kioptrix4_Hyper_v.rarDownload …

Linux驱动开发详细解析

Linux驱动开发详细解析 驱动概念 驱动与底层硬件直接打交道&#xff0c;充当了硬件与应用软件中间的桥梁。 具体任务 读写设备寄存器&#xff08;实现控制的方式&#xff09;完成设备的轮询、中断处理、DMA通信&#xff08;CPU与外设通信的方式&#xff09;进行物理内存向虚…

关于字符设备驱动的通用概念和写法

概述 设备驱动程序可以使用模块的方式动态加载到内核中去。加载模块的方式与以往的应用程序开发有很大的不同。以往在开发应用程序时都有一个 main()函数作为程序的入口点&#xff0c;而在驱动开发时却没有 main()函数&#xff0c;模块在调用 insmod 命令时被加载&#xff0c;…

JVM学习笔记一:类加载子系统

目录 前言 类加载子系统的作用 类加载器角色的位置 类加载器分类 虚拟机自带的加载器 启动类加载器&#xff08;引导类加载器&#xff09; 扩展类加载器 系统类加载器 用户自定义类加载器 什么时候需要自定义类加载器&#xff1f; 如何自定义类加载器&#xff1f; …

【验证码的识别】—— 极验验证码的识别

前言 &#xff08;结尾有彩蛋欧&#xff09; 目前&#xff0c;许多网站采取各种各样的措施来反爬虫&#xff0c;其中一个措施便是使用验证码。随着技术的发展&#xff0c;验证码的花样越来越多。验证码最初是几个数字组合的简单的图形验证码&#xff0c;后来加入了英文字母和混…

《计算机系统基础》——计算机系统导论

文章目录《计算机系统基础》——计算机系统导论计算机的基本组成程序开发与执行过程机器语言汇编语言高级语言程序的转换处理程序的数据流动计算机系统层次结构早期计算机系统1GL2GL现代计算机系统3GL4GL指令集体系结构《计算机系统基础》——计算机系统导论 &#x1f680;接下…

LaTeX中表格过宽解决方案

最近使用LaTeX处理表格时遇到了一件十分棘手的问题&#xff0c;由于内容较多将表格分成了好多列&#xff0c;但将内容填入表格时由于表格宽度过大&#xff0c;导致表格右边溢出了页面无法查看&#xff0c;查阅大量资料与博文后给出如下解决方案&#xff0c;全文代码已部署在Ove…

C#基础练习题,编程题汇总

C#基础练习题&#xff0c;编程题汇总一、C#提取输入的最大整数二、秒数换算为相应的时、分、秒三、C#计算电梯运行用时demo四、C#用一维数组求解问题五、C#程序教小学生学乘法六、C#winfrm简单例题七、C#类继承习题八、C#绘图例子一、C#提取输入的最大整数 编程实现在一行内输…