机器学习——4.案例: 简单线性回归求解

news2024/12/27 11:49:53

案例目的

寻找一个良好的函数表达式,该函数表达式能够很好的描述上面数据点的分布,即对上面数据点进行拟合。

求解逻辑步骤

  1. 使用Sklearn生成数据集
  2. 定义线性模型
  3. 定义损失函数
  4. 定义优化器
  5. 定义模型训练方法(正向传播、计算损失、反向传播、梯度清空)
  6. 模型训练
  7. 模型预测与线性关系展示

代码实现

import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
import torch


# 生成数据集 n_samples-样本数量,n_features-自变量数量,random_state-随机种子,noise-噪声
data = datasets.make_regression(n_samples=100,n_features=1,random_state=5,noise=10)
X,Y = data
# 数据集转换成张量
X = torch.from_numpy(X.astype(np.float32))
Y = torch.from_numpy(Y.astype(np.float32))
# 行列形状要相同
Y = Y.view(100,1)

# 线性模型函数的定义
n_samples,n_features = X.size()
model = torch.nn.Linear(n_features,1)

# 定义损失函数
loss = torch.nn.MSELoss()

# 定义优化器
learn_rate = 0.01 
optimizer = torch.optim.SGD(model.parameters(),lr=learn_rate)

# 实现梯度下降函数
def gradient_descent():
    # 正向传播
    pre_y = model(X)
    # 计算损失
    l = loss(pre_y,Y)
    # 反向传播
    l.backward()
    # 梯度更新
    optimizer.step()
    # 梯度清空
    optimizer.zero_grad()
    return l,list(model.parameters())

# 模型训练
for i in range(500):
    l,parameters = gradient_descent()
    print(l,parameters)

# 模型预测
predect = model(X)

# X,Y线性拟合效果展示
plt.scatter(X,Y)
plt.plot(X,predect.detach().numpy(),color="r")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1647818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Implicit Diffusion Models for Continuous Super-Resolution

CVPR2023https://github.com/Ree1s/IDM问题引入: – LIIF方法可以实现任意分辨率的输出,但是因为是regression-based方法,所以得到的结果缺少细节,而生成的方法(gan-based,flow-based,diffusion-based等)可以生成细节&…

JavaScript中的RegExp和Cookie

个人主页:学习前端的小z 个人专栏:JavaScript 精粹 本专栏旨在分享记录每日学习的前端知识和学习笔记的归纳总结,欢迎大家在评论区交流讨论! 文章目录 🔆RegExp 🎲 1 什么是正则表达式 🎲2 创建…

从招标到合作:如何筛选与企业需求匹配的6sigma咨询公司

在市场竞争激烈的环境中,领军企业需要不断改进和创新才能在行业中保持竞争优势。为了解决产品质量、生产流程和客户满意度等方面的挑战,许多企业选择与6sigma咨询公司合作,推动企业的全面变革和持续发展。下面是企业在选择合作伙伴时通常会经…

一、Redis五种常用数据类型

Redis优势: 1、性能高—基于内存实现数据的存储 2、丰富的数据类型 5种常用,3种高级 3、原子—redis的所有单个操作都是原子性,即要么成功,要么失败。其多个操作也支持采用事务的方式实现原子性。 Redis特点: 1、支持…

Golang日志实战教程:掌握log与syslog库的高效使用

Golang日志实战教程:掌握log与syslog库的高效使用 简介理解 Golang 的 log 库基本概念创建日志记录器自定义日志记录器日志级别 深入 syslogsyslog 的基础配置和使用 syslog高级应用 日志格式化与管理日志格式化日志文件管理 日志的高级应用集成第三方日志框架使用 …

Python程序中温度更新出现振荡问题的分析和解决方案

在处理温度更新出现振荡问题时,可以考虑以下分析和解决方案:检查温度更新算法是否正确,可能存在错误导致振荡。检查温度更新的步长(时间步长)是否合适,步长过大可能导致振荡。检查系统动力学模型是否准确&a…

场外个股期权和场内个股期权的优缺点是什么?

场外个股期权和场内个股期权的优缺点 场外个股期权是指在沪深交易所之外交易的个股期权,其本质是一种金融衍生品,允许投资者在股票交易场所外以特定价格买进或卖出证券。场内个股期权是以单只股票作为标的资产的期权合约,其内在价值是基于标…

如何用Kimi,5秒1步生成流程图

引言 在当前快节奏的工作环境中,拥有快速、专业且高效的工具不可或缺。 Kimi不仅能在5秒内生成专业的流程图(kimi),还允许实时编辑和预览,大幅简化了传统流程图的制作过程。 这种迅速的生成能力和高度的可定制性使得…

员工账号生命周期如何“全场景”自动化管理?

当企业在信息化建设中引入越来越多的业务系统时,必然存在系统内账号互相独立、无法打通的情况。一有人事变动,HR、IT 管理员、应用管理员、业务部门主管等人就需要在系统里手动更新账号状态。重复、低效,且不可避免出现安全隐患。困扰着 IT 管…

冯喜运:5.7全球紧张局势中,黄金原油投资者转向需谨慎

【黄金消息面分析】:周一(5月6日),现货黄金触底回升,盘中交投于2320美元附近。自美国4月非农就业数据出炉和美联储主席鲍威尔货币政策新闻发布会以后,现货黄金从4月12日的历史高点2431美元下跌了大约6.3%&a…

AI口语对话训练有哪些软件?推荐这5款,简单易用

AI口语对话训练有哪些软件?AI口语对话训练软件在近年来得到了飞速的发展,为语言学习者提供了更为便捷、高效的学习方式。它们借助先进的自然语言处理技术和机器学习算法,不仅模拟了真实对话场景,还提供了个性化的学习建议和即时反…

笔试强训Day15 二分 图论

平方数 题目链接&#xff1a;平方数 (nowcoder.com) 思路&#xff1a;水题直接过。 AC code&#xff1a; #include<iostream> #include<cmath> using namespace std; int main() {long long int n; cin >> n;long long int a sqrtl(n);long long int b …

苏州金龙荣获首届无人扫地机器人演示比赛“竞技领跑奖”

4月30日&#xff0c;2024年苏州市首届无人扫地机器人演示比赛在高新区思益街展开比拼。五家企业参赛在道路上实地比拼无人扫地机器人技术&#xff0c;通过清扫垃圾、识别路障等环节展现城市清洁的“未来场景”。经过角逐&#xff0c;苏州金龙的无人驾驶清扫车获得步道演示比赛“…

送给正在入行的小白:最全最有用的网络安全学习路线已经安排上了

在这个圈子技术门类中&#xff0c;工作岗位主要有以下三个方向&#xff1a; 安全研发安全研究&#xff1a;二进制方向安全研究&#xff1a;网络渗透方向 下面逐一说明一下。 第一个方向&#xff1a;安全研发 你可以把网络安全理解成电商行业、教育行业等其他行业一样&#xf…

centos7.9系统rabbitmq3.8.5升级为3.8.35版本

说明 本文仅适用rabbitmq为RPM安装方式。 升级准备 查看环境当前版本&#xff1a; # cat /etc/redhat-release CentOS Linux release 7.9.2009 (Core) # rabbitmqctl status Status of node rabbitmq01 ... RuntimeOS PID: 19333 OS: Linux Uptime (seconds): 58 Is under …

Ollama +Docker+OpenWebUI

1 Ollama 1.1 下载Ollama https://ollama.com/download 1.2 运行llama3 $ ollama run llama3 pulling manifest pulling 00e1317cbf74... 100% ▕███████████████████████████████████████████████████████████…

Java 运行的底层原理

Java是一种跨平台的编程语言&#xff0c;其底层原理涉及到了多个方面&#xff0c;包括Java虚拟机&#xff08;JVM&#xff09;、字节码、类加载机制、垃圾回收器等。让我们逐一深入了解Java运行的底层原理。 1. Java虚拟机&#xff08;JVM&#xff09; Java虚拟机是Java程序运…

【前端】HTML基础(1)

文章目录 前言一、什么是前端二、HTML基础1、 HTML结构1.1 什么是HTML页面1.2 认识HTML标签1.3 HTML文件基本结构1.3 标签层次结构1.4 创建html文件1.5 快速生成代码框架 三、Emmet快捷键 前言 这篇博客仅仅是对HTML的基本结构进行了一些说明&#xff0c;关于HTML的更多讲解以及…

我国烟雾报警器市场规模逐渐增长 市场集中度相对较低

我国烟雾报警器市场规模逐渐增长 市场集中度相对较低 烟雾报警器又称为烟雾探测器、烟感报警器等&#xff0c;是用于检测室内烟雾浓度、实现火灾防范的一种安全设备。烟雾报警器具有反应速度快、灵敏度高、功耗低等优点。根据工作原理不同&#xff0c;烟雾报警器可分为热敏式、…

使用Vue3开发项目,搭建Vue cli3项目步骤

1.打开cmd &#xff0c;输入 vue create neoai遇到这样的问题 则需要升级一下电脑上 Vue Cli版本哈 升级完成之后 再次输入命令&#xff0c;创建vue3项目 vue create neoai安装完成后&#xff0c;输入 npm run serve 就可以运行项目啦~ 页面运行效果