【深度学习实验】线性模型(二):使用NumPy实现线性模型:梯度下降法

news2024/11/24 11:28:58

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 初始化参数

2. 线性模型 linear_model

3. 损失函数loss_function

4. 梯度计算函数compute_gradients

5. 梯度下降函数gradient_descent

6. 调用函数


一、实验介绍

        使用NumPy实现线性模型:梯度下降法

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

         线性模型梯度下降法是一种常用的优化算法,用于求解线性回归模型中的参数。它通过迭代的方式不断更新模型参数,使得模型在训练数据上的损失函数逐渐减小,从而达到优化模型的目的。

        梯度下降法的基本思想是沿着损失函数梯度的反方向更新模型参数。在每次迭代中,根据当前的参数值计算损失函数的梯度,然后乘以一个学习率的因子,得到参数的更新量。学习率决定了参数更新的步长,过大的学习率可能导致错过最优解,而过小的学习率则会导致收敛速度过慢。

具体而言,对于线性回归模型,梯度下降法的步骤如下:

  1. 初始化模型参数,可以随机初始化或者使用一些启发式的方法。

  2. 循环迭代以下步骤,直到满足停止条件(如达到最大迭代次数或损失函数变化小于某个阈值):

    a. 根据当前的参数值计算模型的预测值。

    b. 计算损失函数关于参数的梯度,即对每个参数求偏导数。

    c. 根据梯度和学习率更新参数值。

    d. 计算新的损失函数值,并检查是否满足停止条件。

  3. 返回优化后的模型参数。

       本实验中,gradient_descent函数实现了梯度下降法的具体过程。它通过调用initialize_parameters函数初始化模型参数,然后在每次迭代中计算模型预测值、梯度以及更新参数值。

0. 导入库

import numpy as np

1. 初始化参数

        在梯度下降算法中,需要初始化待优化的参数,即权重 w 和偏置 b。可以使用随机初始化的方式。

def initialize_parameters():
    w = np.random.randn(5)
    b = np.random.randn(5)
    return w, b

2. 线性模型 linear_model

def linear_model(x, w, b):
    output = np.dot(x, w) + b
    return output

3. 损失函数loss_function

         该函数接受目标值y和模型预测值prediction,计算均方误差损失。

def loss_function(y, prediction):
    loss = (prediction - y) * (prediction - y)
    return loss

4. 梯度计算函数compute_gradients

        为了使用梯度下降算法,需要计算损失函数关于参数 w 和 b 的梯度。可以使用数值计算的方法来近似计算梯度。

def compute_gradients(x, y, w, b):
    h = 1e-6  # 微小的数值,用于近似计算梯度
    grad_w = (loss_function(y, linear_model(x, w + h, b)) - loss_function(y, linear_model(x, w - h, b))) / (2 * h)
    grad_b = (loss_function(y, linear_model(x, w, b + h)) - loss_function(y, linear_model(x, w, b - h))) / (2 * h)
    return grad_w, grad_b

5. 梯度下降函数gradient_descent

        根据梯度计算的结果更新参数 w 和 b,从而最小化损失函数。

def gradient_descent(x, y, learning_rate, num_iterations):
    w, b = initialize_parameters()
    for i in range(num_iterations):
        prediction = linear_model(x, w, b)
        grad_w, grad_b = compute_gradients(x, y, w, b)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
        loss = loss_function(y, prediction)
        print("Iteration", i, "Loss:", loss)
    return w, b

6. 调用函数

        执行梯度下降优化:调用 gradient_descent 函数并传入数据 x 和 y,设置学习率和迭代次数进行优化。

x = np.random.rand(5)
y = np.array([1, -1, 1, -1, 1]).astype('float')
learning_rate = 0.1
num_iterations = 100
w_optimized, b_optimized = gradient_descent(x, y, learning_rate, num_iterations)

        在上述代码中,每一次迭代都会打印出当前迭代次数和对应的损失值。通过不断更新参数 w 和 b,使得损失函数逐渐减小,达到最小化损失函数的目的。

希望这个详细解析能够帮助你优化代码并使用梯度下降算法最小化损失函数。如果还有其他问题,请随时提问!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lombok依赖

一.介绍 Project Lombok 是一个 Java 库,它会自动插入编辑器和构建工具,为您的 Java 增添趣味。永远不要再写另一个 getter 或 equals 方法,使用一个注释,您的类有一个功能齐全的构建器,自动化您的日志记录变量等等。…

2023陇剑杯

2023陇剑杯初赛WP HW hard_web_1 ​ 首先判断哪个是服务器地址 ​ 从响应包看,给客户端返回数据包的就是服务器 所以确定服务器地址是192.168.162.188​ 再从开放端口来看,长期开放的端口 客户端发送一个TCP SYN包(同步请求&#xff…

VisualBox QA

出现提示注册表错误,或者之前正常,重启VisualBox后,VM运行失败时,可通过正确卸载VisualBox,然后使用注册表清理软件(CCleaner)清理注册表后,重装VisualBox,即会正常。(一般用这个能解…

CSS Id和Class选择器

文章目录 CSS id 选择器示例 CSS class 选择器CSS id和class的区别和相同点 CSS id 选择器 CSS的id选择器是以“#”开头的,用于选择具有特定id属性的HTML元素。 在HTML文档中,每个id应该是全局唯一的,也就是说,每个id只能用于一…

复杂场景:揭秘新生代光伏独角兽企业的数据管理秘诀

项目背景 最新一个光伏独角兽诞生了。 投资界获悉,一道新能源科技股份有限公司(以下简称“一道新能”)完成Pre-IPO融资。经多家投资方核实,此轮投后估值近80亿元。 一道新能源科技股份有限公司,成立于2018年8月&…

就业创业证查询

这里写目录标题 问题描述结果 问题描述 全国就业创业证查询系统自改版本后不支持根据姓名身份证号查询了,从社保局查又需要证书编号。 结果 经过不谢努力找到了解决办法,可以根据身份证姓名批量查询人员是否有就业证。

【Flowable】使用UEL整合Springboot从0到1(四)

前言 在前面我们介绍了Springboot简单使用了foleable以及flowableUI的安装和使用,在之前我们分配任务的处理人的时候都是通过Assignee去指定固定的人的。这在实际业务中是不合适的,我们希望在流程中动态的去解析每个节点的处理人,当前flowab…

家里电脑怎样远程办公室电脑?快解析映射域名实现内网穿透

远程电脑怎么操作是大家比较关注的问题,特别是涉及内外网,不在同一个局域网内不同计算机间的远程连接访问,如家里电脑怎样远程办公室电脑?这里提供一种简便的异地远程方法:用快解析。通过快解析映射域名软件&#xff0…

【漏洞复现】Smanga未授权远程代码执行漏洞(CVE-2023-36076) 附加SQL注入+任意文件读取

文章目录 前言声明一、产品简介一、漏洞描述二、漏洞等级三、影响范围四、漏洞复现五、修复建议六、附加漏洞漏洞一、SQL注入漏洞二、任意文件读取 前言 Smanga存在未授权远程代码执行漏洞,攻击者可在目标主机执行任意命令,获取服务器权限。 声明 请勿利用文章内的相关技术从…

【面试题】——Java基础篇(33题)

文章目录 1. 八大基本数据类型分类2. 重写和重载的区别3. int和integer区别4. Java的关键字5. 什么是自动装箱和拆箱?6. 什么是Java的多态性?7. 接口和抽象类的区别?8. Java中如何处理异常?9. Java中的final关键字有什么作用&…

Java文字描边效果实现

效果: FontUtil工具类的完整代码如下: 其中实现描边效果的函数为:generateAdaptiveStrokeFontImage() package com.ncarzone.data.contentcenter.biz.img.util;import org.springframework.core.io.ClassPathResource; import org.springfr…

爱思唯尔——利用AI来改善医疗决策和科研

爱思唯尔(Elsevier)是一家全球性的多媒体出版公司,为教育、专业科学和医疗社区提供20,000多种产品,其中包括《柳叶刀》和《细胞》等领先的研究出版物。 该公司正处于数字化转型的第一阶段,将公司140年中发表在报告和期刊上的大量数据数字化。…

小米华为,化干戈为玉帛!

近日来,手机圈又掀起了各大厂家推出新品的高潮。首先是华为Mate60的推出,其自研的麒麟9000S芯片瞬间点燃了国内手机市场,得到了国内甚至国外业界人士的认可和好评。 而近日网上盛传的小米创始人雷军的“愿意加入华为技术生态圈”的邀请&…

CRM客户管理软件对出海企业的帮助与好处

2023我们走出了疫情的阴霾,经济下行压力大,面对内需的不足,国内企业纷纷选择出海,拓展海外业务增加企业营收。企业出海不是一件易事,有了CRM系统可以让公司事半功倍,下面就来说一说CRM客户管理软件能为出海…

亚马逊应该怎么快速提升排名,获取review?

跨境电商做久了,卖家都会陷入一个困境,到底是该坚持慢慢做好,还是要测评? 现在跨境电商平台人人都在刷,不刷单想成功真的很难,不是没可能,但是选品要非常好,而且你的listing也要做好…

网络编程 day1

1->x.mind网络编程基础 2->简述字节序的概念,并用共用体(联合体)的方式计算本机的字节序 1.字节序是指不同类型的CPU主机,内存存储多字节整数序列的方式 2.小端字节序:低序字节存储在低地址上 3.大端字节序&a…

ROS 入门

目录 简介 ROS诞生背景 ROS的设计目标 ROS与ROS2 安装ROS 1.配置ubuntu的软件和更新 2.设置安装源 3.设置key 4.安装 5.配置环境变量 安装可能出现的问题 安装构建依赖 卸载 ROS架构 1.设计者 2.维护者 3. 立足系统架构: ROS 可以划分为三层 ROS通信机制 话…

007-第一代软件需求整理

第一代软件需求整理 文章目录 第一代软件需求整理项目介绍需求来源需求来源1:竞品软件分析需求来源2:医生(市场)需求来源3:项目组内部需求来源4:软件组内部需求来源5:软件开发成员需求来源6&…

Java精品项目源码第61期垃圾分类科普平台(代号V061)

Java精品项目源码第61期垃圾分类科普平台(代号V061) 大家好,小辰今天给大家介绍一个垃圾分类科普平台,演示视频公众号(小辰哥的Java)对号查询观看即可 文章目录 Java精品项目源码第61期垃圾分类科普平台(代号V061)难度指数&…

LeetCode: 4. Median of Two Sorted Arrays

LeetCode - The Worlds Leading Online Programming Learning Platform 题目大意 给定两个大小为 m 和 n 的有序数组 nums1 和 nums2。 请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m n))。 你可以假设 nums1 和 nums2 不会同时为空。 …