深度学习之pytorch实现线性回归

news2025/1/12 23:04:07

度学习之pytorch实现线性回归

  • pytorch用到的函数
    • torch.nn.Linearn()函数
    • torch.nn.MSELoss()函数
    • torch.optim.SGD()
  • 代码实现
  • 结果分析

pytorch用到的函数

torch.nn.Linearn()函数

torch.nn.Linear(in_features, # 输入的神经元个数
           out_features, # 输出神经元个数
           bias=True # 是否包含偏置
           )

在这里插入图片描述

作用j进行线性变换
Linear(1, 1) : 表示一维输入,一维输出

torch.nn.MSELoss()函数

在这里插入图片描述

torch.optim.SGD()

优化器对象
在这里插入图片描述

代码实现

import torch

x_data = torch.tensor([[1.0], [2.0], [3.0]])  # 将x_data设置为tensor类型数据
y_data = torch.tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()  # 继承父类
        self.linear = torch.nn.Linear(1, 1)
        # 用torch.nn.Linear来构造对象  (y = w * x + b)
        
    def forward(self, x):
        y_pred = self.linear(x) #调用之前的构造的对象(调用构造函数),计算 y = w * x + b
        return y_pred


model = LinearModel()

criterion = torch.nn.MSELoss(size_average=False)  # 定义损失函数,不求平均损失(为False)

#优化器对象
# #model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
# #类似权重的更新
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 定义梯度优化器为随机梯度下降

for epoch in range(10000):  # 训练过程
    y_pred = model(x_data)  # 向前传播,求y_pred
    loss = criterion(y_pred, y_data)  # 根据y_pred和y_data求损失
    print(epoch, loss)
    
    # 记住在backward之前要先梯度归零
    
    optimizer.zero_grad()  # 将优化器数值清零
    loss.backward()  # 反向传播,计算梯度
    optimizer.step()  # 根据梯度更新参数


#打印权重和b
print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())


#检测模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred = ', y_test.data)  # 测试

结果分析

9961 tensor(4.0927e-12, grad_fn=)
9962 tensor(4.0927e-12, grad_fn=)
9963 tensor(4.0927e-12, grad_fn=)
9964 tensor(4.0927e-12, grad_fn=)
9965 tensor(4.0927e-12, grad_fn=)
9966 tensor(4.0927e-12, grad_fn=)
9967 tensor(4.0927e-12, grad_fn=)
9968 tensor(4.0927e-12, grad_fn=)
9969 tensor(4.0927e-12, grad_fn=)
9970 tensor(4.0927e-12, grad_fn=)
9971 tensor(4.0927e-12, grad_fn=)
9972 tensor(4.0927e-12, grad_fn=)
9973 tensor(4.0927e-12, grad_fn=)
9974 tensor(4.0927e-12, grad_fn=)
9975 tensor(4.0927e-12, grad_fn=)
9976 tensor(4.0927e-12, grad_fn=)
9977 tensor(4.0927e-12, grad_fn=)
9978 tensor(4.0927e-12, grad_fn=)
9979 tensor(4.0927e-12, grad_fn=)
9980 tensor(4.0927e-12, grad_fn=)
9981 tensor(4.0927e-12, grad_fn=)
9982 tensor(4.0927e-12, grad_fn=)
9983 tensor(4.0927e-12, grad_fn=)
9984 tensor(4.0927e-12, grad_fn=)
9985 tensor(4.0927e-12, grad_fn=)
9986 tensor(4.0927e-12, grad_fn=)
9987 tensor(4.0927e-12, grad_fn=)
9988 tensor(4.0927e-12, grad_fn=)
9989 tensor(4.0927e-12, grad_fn=)
9990 tensor(4.0927e-12, grad_fn=)
9991 tensor(4.0927e-12, grad_fn=)
9992 tensor(4.0927e-12, grad_fn=)
9993 tensor(4.0927e-12, grad_fn=)
9994 tensor(4.0927e-12, grad_fn=)
9995 tensor(4.0927e-12, grad_fn=)
9996 tensor(4.0927e-12, grad_fn=)
9997 tensor(4.0927e-12, grad_fn=)
9998 tensor(4.0927e-12, grad_fn=)
9999 tensor(4.0927e-12, grad_fn=)

w = 1.9999985694885254
b = 2.979139480885351e-06
y_pred = tensor([8.0000])

因为轮数过多,这里展示后面几轮
模型的准确性,跟轮数的多少有关系 ,如果轮数为100,最后测试结果的y_pred肯定不为8.00,这里轮数为10000,预测结果跟实际结果基本一样

这里是轮数为100,结果是 7点多,有一定误差
0 tensor(101.4680, grad_fn=)
1 tensor(45.8508, grad_fn=)
2 tensor(21.0819, grad_fn=)
3 tensor(10.0458, grad_fn=)
4 tensor(5.1234, grad_fn=)
5 tensor(2.9227, grad_fn=)
6 tensor(1.9338, grad_fn=)
7 tensor(1.4844, grad_fn=)
8 tensor(1.2754, grad_fn=)
9 tensor(1.1736, grad_fn=)
10 tensor(1.1195, grad_fn=)
11 tensor(1.0869, grad_fn=)
12 tensor(1.0639, grad_fn=)
13 tensor(1.0453, grad_fn=)
14 tensor(1.0288, grad_fn=)
15 tensor(1.0134, grad_fn=)
16 tensor(0.9985, grad_fn=)
17 tensor(0.9841, grad_fn=)
18 tensor(0.9699, grad_fn=)
19 tensor(0.9559, grad_fn=)
20 tensor(0.9421, grad_fn=)
21 tensor(0.9286, grad_fn=)
22 tensor(0.9153, grad_fn=)
23 tensor(0.9021, grad_fn=)
24 tensor(0.8891, grad_fn=)
25 tensor(0.8764, grad_fn=)
26 tensor(0.8638, grad_fn=)
27 tensor(0.8513, grad_fn=)
28 tensor(0.8391, grad_fn=)
29 tensor(0.8271, grad_fn=)
30 tensor(0.8152, grad_fn=)
31 tensor(0.8034, grad_fn=)
32 tensor(0.7919, grad_fn=)
33 tensor(0.7805, grad_fn=)
34 tensor(0.7693, grad_fn=)
35 tensor(0.7582, grad_fn=)
36 tensor(0.7474, grad_fn=)
37 tensor(0.7366, grad_fn=)
38 tensor(0.7260, grad_fn=)
39 tensor(0.7156, grad_fn=)
40 tensor(0.7053, grad_fn=)
41 tensor(0.6952, grad_fn=)
42 tensor(0.6852, grad_fn=)
43 tensor(0.6753, grad_fn=)
44 tensor(0.6656, grad_fn=)
45 tensor(0.6561, grad_fn=)
46 tensor(0.6466, grad_fn=)
47 tensor(0.6373, grad_fn=)
48 tensor(0.6282, grad_fn=)
49 tensor(0.6192, grad_fn=)
50 tensor(0.6103, grad_fn=)
51 tensor(0.6015, grad_fn=)
52 tensor(0.5928, grad_fn=)
53 tensor(0.5843, grad_fn=)
54 tensor(0.5759, grad_fn=)
55 tensor(0.5676, grad_fn=)
56 tensor(0.5595, grad_fn=)
57 tensor(0.5514, grad_fn=)
58 tensor(0.5435, grad_fn=)
59 tensor(0.5357, grad_fn=)
60 tensor(0.5280, grad_fn=)
61 tensor(0.5204, grad_fn=)
62 tensor(0.5129, grad_fn=)
63 tensor(0.5056, grad_fn=)
64 tensor(0.4983, grad_fn=)
65 tensor(0.4911, grad_fn=)
66 tensor(0.4841, grad_fn=)
67 tensor(0.4771, grad_fn=)
68 tensor(0.4703, grad_fn=)
69 tensor(0.4635, grad_fn=)
70 tensor(0.4569, grad_fn=)
71 tensor(0.4503, grad_fn=)
72 tensor(0.4438, grad_fn=)
73 tensor(0.4374, grad_fn=)
74 tensor(0.4311, grad_fn=)
75 tensor(0.4250, grad_fn=)
76 tensor(0.4188, grad_fn=)
77 tensor(0.4128, grad_fn=)
78 tensor(0.4069, grad_fn=)
79 tensor(0.4010, grad_fn=)
80 tensor(0.3953, grad_fn=)
81 tensor(0.3896, grad_fn=)
82 tensor(0.3840, grad_fn=)
83 tensor(0.3785, grad_fn=)
84 tensor(0.3730, grad_fn=)
85 tensor(0.3677, grad_fn=)
86 tensor(0.3624, grad_fn=)
87 tensor(0.3572, grad_fn=)
88 tensor(0.3521, grad_fn=)
89 tensor(0.3470, grad_fn=)
90 tensor(0.3420, grad_fn=)
91 tensor(0.3371, grad_fn=)
92 tensor(0.3322, grad_fn=)
93 tensor(0.3275, grad_fn=)
94 tensor(0.3228, grad_fn=)
95 tensor(0.3181, grad_fn=)
96 tensor(0.3136, grad_fn=)
97 tensor(0.3091, grad_fn=)
98 tensor(0.3046, grad_fn=)
99 tensor(0.3002, grad_fn=)
w = 1.6352288722991943
b = 0.8292105793952942
y_pred = tensor([7.3701])

Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1456587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux防火墙:SNAT和DNAT地址转换操作

目录 一、NAT 1、NAT概念 2、NAT分类 二、SNAT 1、SNAT概念 2、SNAT源地址转换过程 3、已知外网地址的SNAT操作 3.1 配置网关服务器 3.1.1 添加网卡 3.1.2 配置ens33网卡 3.1.3 配置ens36网卡 3.1.4 重启网卡并查看网卡是否生效 3.1.5 开启路由转发功能 3.1.6 配…

Springboot医院信息管理系统源码 带电子病历和LIS Saas应用+前后端分离+B/S架构

目录 系统特点 技术架构 系统功能 1、 标准数据维护 2、 收费(门诊/住院)系统 3、 药剂管理系统 4、 医生工作站系统 5、 护士工作站系统 6、电子病历系统 系统优点 云HIS系统简介 云HIS系统功能模块 门急诊挂号管理 门诊收费管理 门诊医…

嵌入式面试:瑞芯微

文章目录 一、2024 秋招1.1 IIC的速率范围 :1.2 linux驱动子系统汇总 :1.3 linux关抢占情况汇总 :1.4 操作或者读写一个文件时,从用户态到内核态再到物理介质的流程(考点:虚拟文件系统) : 一、2024 秋招 1…

Model / View结构

红色部分是可以直接使用的。 QFileSystemModel; QFileSystemModel的使用: 头文件: QFileSystemModel* model nullptr; cpp文件: model new QFileSystemModel; model->setRootPath(QDir::currentPath()); ui->listView->setModel…

LDR6020打造最具有性价比的TYPE-C台式显示器方案

对于手里有TYPE-C接口电脑设备,又觉得自带屏幕太小,需要换用外接屏幕,或者需要多屏办公的用户。肯定要首选支持Type-c连接的显示器了。为什么呢?因为Type-c连接可以战未来,而不是仅仅能满足现在的需求。 首先介绍一下…

【vue+leaflet】vue项目中使用leaflet绘制室内平面图、leaflet.pm在平面图中绘制点、线、面图层(一)

效果图: 一,插件安装 npm i leaflet --save // 我的版本^1.9.4 npm i leaflet.pm --save // 我的版本^2.2.0附官网链接: leaflet官网: https://leafletjs.com/index.html leaflet.pm官网: https://www.npmjs.com/package/leaflet.pm?activeTabreadme 二,模块引入 因为我…

语义相关性评估指标:召回率、准确率、Roc曲线、AUC;Spearman相关系数、NDCG、mAP。代码及计算示例。

常规的语义相关性评价可以从检索、排序两个方面进行。这里只贴代码。详细可见知乎https://zhuanlan.zhihu.com/p/682853171 检索 精确率 def pre(true_labels[],pre_labels[]):""":param true_labels: 正样本索引:param pre_labels: 召回样本索引:return: 精…

nacos 2.3.1-SNAPSHOT 源码springboot方式启动(详细)附改造工程地址

文章时间是2024-2-18日,nacos默认develop分支,最新版是2.3.1-SNAPSHOT版本。 我们这里就以nacos最新版进行改造成springboot启动方式。 1. Clone 代码 nacos github地址:https://github.com/alibaba/nacos.git 根据上面git地址把源码克隆到…

wps快速生成目录及页码设置(自备)

目录 第一步目录整理 标题格式设置 插入页码(罗马和数字) 目录生成(从罗马尾页开始) ​编辑目录格式修改 第一步目录整理 1罗马标题 2罗马标题1一级标题 1.1 二级标题 1.2二级标题2一级标题 2.1 二级标题 2.2二级标题3一级标…

开源模型应用落地-工具使用篇-SLB(二)

一、前言 通过学习"开源模型应用落地"系列文章,我们成功地建立了一个完整可实施的AI交付流程。现在,我们要引入负载均衡,以提高我们的AI服务的性能和故障转移能力。本文将详细介绍如何使用腾讯云的负载均衡技术来将我们的AI服务部署…

网页脚本 bilibili004:字幕展示添加下载功能实现

效果 按钮显示 按钮hover 按钮点击 代码实现 在main函数中添加下载逻辑 getVideoName().then((resultObject) > {// 处理异步的Promise对象,https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/thenaddDownloadButt…

Fiddler与wireshark使用

Fiddler解决三个问题 1、SSL证书打勾,解析https请求 2、响应回来乱码,不是中文 3、想及时中止一下,查看实时的日志 4、搜索对应的关键字 问题1解决方案: 标签栏Tools下 找到https,全部打勾 Actions里面 第一个 t…

沁恒CH32V30X学习笔记00--芯片概述

芯片概述 资源 系统框图 V303时钟树 V305/V307时钟 RISC-V4F 处理器 单精度浮点运算 处理器内部以模块化管理, 包含快速可编程中断控制器(PFIC) 内存保护 分支预测模式 扩展指令支持等单元 小端数据模式 多级硬件中断堆栈&#

ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整

文章目录 ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整获取音频长度pydub获取音频长度获取时长精确到秒格式设定 mutagen获取音频长度 调整音量视频音量调整注意事项 ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整 老王媳妇说上次那个pip挺好…

『运维备忘录』之 SSH 命令详解

运维人员不仅要熟悉操作系统、服务器、网络等知识,甚至对于开发相关的也要有所了解。很多运维工作者可能一时半会记不住那么多命令、代码、方法、原理或者用法等等。这里我将结合自身工作,持续给大家更新运维工作所需要接触到的知识点,希望大…

OpenAI最新模型Sora到底有多强?眼见为实的真实世界即将成为过去!

文章目录 1. 写在前面2. 什么是Sora?3. Sora的技术原理 【作者主页】:吴秋霖 【作者介绍】:Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作! 【作者推荐】:对JS逆向感…

Linux-文件文件夹相关命令

目录 常见命令 1. 创建空目录:mkdir 文件夹名 2. 删除空目录:rmdir 文件夹名 3. 创建多级目录:mkdir -p 123/abc 4. 删除非空文件 rm -rf 文件夹名 5. 创建文件: touch 文件名.后缀 / vi 文件名.后缀 6. 删除文件&#x…

挑战杯 基于GRU的 电影评论情感分析 - python 深度学习 情感分类

文章目录 1 前言1.1 项目介绍 2 情感分类介绍3 数据集4 实现4.1 数据预处理4.2 构建网络4.3 训练模型4.4 模型评估4.5 模型预测 5 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于GRU的 电影评论情感分析 该项目较为新颖,适合作为竞…

Spring Boot java -jar --spring.profiles.active=dev 失效问题

之前动态部署修改配置文件的情况不多&#xff0c;所以也没注意过&#xff0c;这个问题今天困扰了好久&#xff0c;经过多方查询后得到了解决办法 直接上代码 <profiles><profile><!-- 本地开发环境 --><id>dev</id><properties><profi…

Codeforces Global Round 6

CF1266A Competitive Programmer 题目 给出n个数,问对于每个数,是否可以将这个数的数位重新组合(可以有前导零), 使其可以被60整除,若可以,则输出red,否则,输出cyan 分析 首先来看被60整除需要满足什么条件&#xff0c;因为602*3*10&#x…