python使用 数值微分法 求梯度,实现单层线性回归

news2024/11/28 18:54:52

文章目录

  • 模型
  • 构建数据
  • 数值微分实现(梯度计算)
  • 模型封装
  • 运行测试
  • 运行结果

主要介绍 数值微分法 求梯度,以及基于此对参数作随机梯度下降,并封装一个简单的线性回归模型以作调试,最后绘制loss图像。

模型

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

构建数据

def build_data(weights, bias, num_examples):  
    x = np.random.randn(num_examples, len(weights))  
    y = x.dot(weights) + bias  
    # 给y加个噪声  
    y += np.random.rand(1)  
    return x, y  
  
  
def data_iter(features, labels, batch_size):  
    num_examples = len(features)  
    # 按样本数量构造索引  
    indices = list(range(num_examples))  
    # 打乱索引数组  
    np.random.shuffle(indices)  
    for i in range(0, num_examples, batch_size):  
        batch_indices = np.array(indices[i:min(i + batch_size, num_examples)])  
        yield features[batch_indices], labels[batch_indices]

数值微分实现(梯度计算)

就是求偏导。

# 基于数值微分+中心差分法 求偏导(梯度)  
def numerical_gradient(f, x):  
    h = 1e-4  # 0.0001  
    grad = np.zeros_like(x)  
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])  
    while not it.finished:  
        idx = it.multi_index  
        tmp_val = x[idx]  
        x[idx] = float(tmp_val) + h  
        fxh1 = f(x)  # f(x+h)  
  
        x[idx] = tmp_val - h  
        fxh2 = f(x)  # f(x-h)  
        grad[idx] = (fxh1 - fxh2) / (2 * h)  
  
        x[idx] = tmp_val  # 还原值  
        it.iternext()  
    return grad

模型封装

class Network:  
    def __init__(self, input_size, output_size, weight_init_std=0.01):  
        self.params = {'w1': np.random.rand(input_size, output_size),  
                       'b1': np.array([0.0])}  
  
    def predict(self, x):  
        w1, b1 = self.params['w1'], self.params['b1']  
        return x.dot(w1) + b1  
  
    def loss(self, x, y):  
        pred_y = self.predict(x)  
        return np.mean(np.square(y - pred_y))  
  
    def numerical_gradient(self, x, y):  
        loss_w = lambda w: self.loss(x, y)  
        grads = dict()  
        grads['w1'] = numerical_gradient(loss_w, self.params['w1'])  
        grads['b1'] = numerical_gradient(loss_w, self.params['b1'])  
        return grads

运行测试

if __name__ == '__main__':  
    start = time.perf_counter()  
  
    # np.random.seed(1)  
    true_w1 = np.random.rand(2, 1)  
    true_b1 = np.random.rand(1)  
    # true_w1 = np.array([[3.0], [4.0]])  
    # true_b1 = np.array([5.0])    x_train, y_train = build_data(true_w1, true_b1, 5000)  
  
    net = Network(2, 1, 0.01)  
    init_loss = net.loss(x_train, y_train)  
  
    print(net.params)  
  
    loss_history = list()  
    loss_history.append(init_loss)  
  
    num_epochs = 2  
    batch_size = 50  
    learning_rate = 0.01  
    for i in range(num_epochs):  
        # running_loss = 0.0  
        for x, y in data_iter(x_train, y_train, batch_size):  
            grads = net.numerical_gradient(x, y)  
            for key in grads:  
                net.params[key] -= learning_rate * grads[key]  
            running_loss = net.loss(x, y)  
            loss_history.append(running_loss)  
  
        # current_loss = net.loss(x_train, y_train)  
        # loss_history.append(current_loss)  
        # print(f'第{i}次:{net.params}')  
  
    plt.title("基于 数值微分+中心差分法 的单层简单线性模型")  
    plt.xlabel("epoch")  
    plt.ylabel("loss")  
    plt.plot(loss_history, linestyle='dotted')  
    plt.show()  
  
    # print(loss_history)  
    print(f'初始损失值:{init_loss}')  
    print(f'最后一次损失值:{loss_history[-1]}')  
  
    print()  
  
    print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  
    print(f'预测参数: true_w1={net.params["w1"]}, true_b1={net.params["b1"]}')  
    print()  
  
    end = time.perf_counter()  
    print(f"运行时间:{(end - start)*1000}毫秒")

运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2034652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java:接口interface

文章目录 接口interface好处为什么要用接口 接口案例需求思路代码Student.javaClassManage.javaStudentOperator 接口StudentOperatorImpl1.javaStudentOperatorImpl2.javaTest.java 黑马程序员学习笔记 接口interface 接口中:变量默认为常量,方法默认为…

上门预约o2o系统源码开发及商业模式探索

随着互联网的飞速发展,O2O(Online to Offline)模式已成为连接线上与线下服务的重要桥梁。上门预约O2O系统作为这一模式下的典型应用,通过整合线上线下资源,为用户提供便捷、高效、个性化的上门服务体验。本文将从商业模…

WebStorm 2024 for Mac/Win:JavaScript开发的高效利器

WebStorm 2024 for Mac/Win是一款专为前端开发者和全栈工程师设计的集成开发环境(IDE),由JetBrains公司精心打造。这款软件以其强大的功能和卓越的性能,在JavaScript及相关技术的开发领域中脱颖而出,成为众多开发者的首…

无人机电池的使用寿命!

无人机电池的循环寿命一般在100次到500次之间,具体取决于电池类型、质量和使用条件。高品质电池和正确的使用方式可以延长电池的循环寿命。 避免极端温度 避免在过高或过低的温度下使用无人机电池,以免影响电池性能和寿命。 正确存储 将电池存放在干燥…

C语言——预处理详解(下)

目录 前言 #和## 1.#运算符 2.##运算符 命名约定 #undef 命令行定义 条件编译 1.单分支条件编译 2.多分支条件编译 3.判断是否被定义 4.嵌套指令 头文件的包含 1.头文件被包含的方式 (1)本地文件包含 (2)库文件包含 2.嵌套文件包含 其他预处理指令 1.#error…

如何用Python进行数据可视化、科技图表绘制?

目录 写在前面 推荐图书 推荐理由 写在最后 写在前面 有了它,科技图表绘制、数据可视化真的毫无难度! 推荐图书 《Python数据可视化:科技图表绘制》(芯智)【摘要 书评 试读】- 京东图书 图书简介 《Python数据可视化:科技图表绘制》结…

mfc140.dll丢失如何修复,一步步教你如何解决mfc140.dll丢失,让电脑快速恢复正常状态!

mfc140.dll是 Microsoft Foundation Class (MFC) Library 的一部分,它是一个用于开发 Windows 应用程序的 C 库。当系统报告mfc140.dll丢失时,通常意味着某个应用程序需要这个 DLL 文件来运行,但系统中没有找到它。那么mfc140.dll丢失如何修复…

Ubuntu下提升高并发socket最大连接数限制

文章目录 前言1. limits.conf修改2. /etc/pam.d修改3. /etc/sysctl.conf修改4. ulimit设置5.重启系统即可生效参考文档 前言 linux系统默认ulimit为1024个访问 用户最多可开启的程序数目。一般一个端口(即一个进程)的最高连接为2的16次方65536。 查看全…

TPshop商城的保姆教程(Ubuntu)

1.上传TPSHOP源码 选择适合自己的版本下载 TPshop商城源文件下载链接: 百度网盘 请输入提取码 上传tpshop的源码包到特定目录/var/www/html 切换到/var/www/html 目录下 cd /var/www/html修改HTML目录下所有文件权限 chmod -R 777 * 2.打开网址配置 TPshop安…

如何以编程方式解析 XCResult 包的内容

文章目录 介绍查找 XCResult 包分享 XCResult 包 解析 XCResult 包自动解析 XCResult 包的内容 使用 XCResultKit 解析包的内容初始化库获取调用记录 获取测试信息导出屏幕录制 可运行 Demo初始化 Swift Package编写主文件代码解释运行 Demo 结论 介绍 XCResult 包是一个包含运…

Apache SeaTunnel 2.3.5 Zeta-Server集群环境搭建与使用

作者 | 月影幽篁 在当前数据驱动的业务环境中,快速且高效的数据处理能力至关重要。Apache SeaTunnel以其卓越的性能和灵活性,成为数据工程师和开发者的首选工具之一。本文将介绍如何在集群环境中搭建Apache SeaTunnel 2.3.5版本的 Zeta-Server&#xff…

期权强大优势之一的杠杆是什么?!

今天带你了解期权强大优势之一的杠杆是什么?!期权是一种合约,该合约赋予持有人在某一特定日期以固定价格买入或卖出一种资产的权利。 期权杠杆是指使用较少的资金控制相对较大金额的股票或其他资产的能力。 期权提供了买入或卖出标的资产的…

U盘救星在此!年度免费数据恢复软件TOP榜

现在这社会,数字信息太重要了,工作文件、学习笔记,还有那些记录美好时光的照片和视频,要是一不小心丢了,那可真是急死人。不过,幸运的是,现在有数据恢复软件,它们就像是数据的救星&a…

Qt多线程编程-run()方法

本文介绍Qt多线程编程-run()方法。 Qt多线程编程主要有2种方法,前面已经介绍了moveToThread()方法,本文介绍另外一种方法run()方法,并给出一个实例参考。 1.基本原理 run()方法首先需要定义一个基于QThread的派生类,QThread类是…

cAdvisor+prometheus+grafana搭建监控页面并嵌入自定义页面中

三者关系 一般公司会有很多docker主机,那么就需要对docker进行监控了,docker监控可以采用docker stats配合shell命令来取值做监控,但是无法传递给prometheus进行采集,zabbix监控docker又比较麻烦,因此就有了谷歌的cad…

Python开源项目周排行 2024年第13周

#2024年第13周2024年8月5日1roop一款基于深度学习框架TensorFlow和Keras开发的单图换脸工具包,提供了丰富的功能和简洁易用的界面,使得用户可以轻松实现单图换脸操作。支持多张人脸替换成同一个人脸,勾选多人脸模式即可 人脸替换 高清修复自…

RCE绕过方式

目录 小于8个字符突破限制 无字母数字执行 php7的做法 php5的思考 PHP5shell 深入理解glob通配符 构造POC,执行任意命令 无参数读文件和RCE总结 代码解读 构造. 另一种构造方法 小于8个字符突破限制 但也只能执行一些非常短的命令,没有什么意义…

【JavaSec】 代码审计01-SpringMVC图书购物系统

【JavaSec】 代码审计01-SpringMVC图书购物系统 文章目录 【JavaSec】 代码审计01-SpringMVC图书购物系统前期部署用户管理修改删除 商品管理修改 普通用户注册 源码地址:https://github.com/Laverrr/bookstore 前期部署 问题一: 启动后报错 Cookie值…

RabbitMQ应用问题 - 消息顺序性保证、消息积压问题

文章目录 MQ 消息顺序性保证概述原因分析解决方案基于 spring-cloud-stream 实现分区消费 消息挤压问题概述原因分析解决方案 MQ 消息顺序性保证 概述 a)消息顺序性:消费者消费的消息的顺序 和 生产者发送消息的顺序是一致的. 例如 生产者 发送消息顺序…

centos7 xtrabackup mysql(8)压缩 增量备份(3)

centos7 xtrabackup mysql(8)压缩 增量备份(3) 添加数据1 添加数据测试一下 测试主从是否可以 主机端 mysql -u root -p 1234aA~1 show databases ; use company_pro; show tables ; insert into employee(name) value (‘2024…