PyTorch - 线性回归

news2025/1/23 11:28:52


普通实现

准备数据

import torch 
import matplotlib.pyplot as plt 

# 1、准备数据
# y = 2 * x + 0.8
x = torch.rand([500, 1])
y_true = 2 * x + 0.8 


# 2、通过模型计算 y_predict
w = torch.rand([1, 1], requires_grad=True) 
b = torch.tensor(0, requires_grad=True, dtype=torch.float32) 
 
learning_rate = 0.01

反向传播

# 4、通过循环,反向传播,更新参数
for i in range(1000):

    y_predict = torch.matmul(x, w) + b  # 预测值
    
    # 3、计算loss
    # 回归问题,使用均方误差 
    loss = (y_true - y_predict).pow(2).mean()

    if w.grad is not None:
        w.grad.data.zero_() 

    if b.grad is not None:
        b.grad.data.zero_() 

    loss.backward() # 反向传播
    
    w.data = w.data - learning_rate * w.grad 
    b.data = b.data - learning_rate * b.grad 

    if i%100 == 1:
        print(f'-- i: {i}, w : {w.item()}, b : {b.item()}, loss : {loss}, ')
  
print(f'-- end w : {w.item()}, b : {b.item()}, loss : {loss}, ')

-- i: 1, w : 0.35467997193336487, b : 0.06384684145450592, loss : 2.716614007949829, 
-- i: 101, w : 1.0814398527145386, b : 1.1525230407714844, loss : 0.08122585713863373, 
-- i: 201, w : 1.238681674003601, b : 1.1865873336791992, loss : 0.049526866525411606, 
-- i: 301, w : 1.3382785320281982, b : 1.1438771486282349, loss : 0.03770563006401062, 
-- i: 401, w : 1.4222863912582397, b : 1.1008449792861938, loss : 0.028764590620994568, 
-- i: 501, w : 1.4954265356063843, b : 1.0628066062927246, loss : 0.021944044157862663, 
-- i: 601, w : 1.5592910051345825, b : 1.029546856880188, loss : 0.016740744933485985, 
-- i: 701, w : 1.61506986618042, b : 1.0004940032958984, loss : 0.012771294452250004, 
-- i: 801, w : 1.6637893915176392, b : 0.9751182794570923, loss : 0.009743028320372105, 
-- i: 901, w : 1.7063422203063965, b : 0.9529542922973633, loss : 0.007432833779603243, 

-- end w : 1.7428138256072998, b : 0.9339574575424194, loss : 0.005701201036572456, 

plt.figure(figsize=(20, 8))
plt.scatter(x.numpy().reshape(-1), y_true.numpy().reshape(-1)) 

y_predict = torch.matmul(x, w) + b  
plt.plot(x.numpy().reshape(-1), y_predict.detach().numpy().reshape(-1), c='r') 
plt.show() 

在这里插入图片描述


测试

x1 = torch.tensor(2)
y1 = w * x1 + b 
x1, y1  

(tensor(2), tensor([[4.4196]], grad_fn=<AddBackward0>))

构建模型 实现

import torch 
from torch import nn
from torch import optim # 优化器

构建数据

x = torch.rand([500, 1])
y_true = 2 * x + 0.8 

构建模型
1、需要调用super方法,继承父类的属性和方法
2、必须实现 forward 方法,用来定义我们网络的前向计算过程。

class LR(nn.Module):
    def __init__(self):
        super(LR, self).__init__() 
        
        # 预定义好的线性模型,也被称为 全链接层
        # 传入的参数为输入的数量、输出的数量(in_features, out_features);是不算 batch_size 的列数。
        # nn.Module 定义了 __call__ 方法,实现的就是 调用 forward 方法,即 LR 的实例,能够直接被传入参数调用;实际调用的是 forward 方法。
        self.linear = nn.Linear(1, 1)  
        
    def forward(self, x):
        out = self.linear(x)
        return out 
     



实例化模型、损失函数、优化器

model = LR() # 实例化模型
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
'''
SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)
''' 

优化器类 optimizer,可以理解为 torch 为我们封装的 用来更新参数的方法。
比如:常见的随机梯度下降(stochastic gradient descent, SGD) 优化器类都是由 torch.optim 提供的,例如:

  • torch.optim.SGD
  • torch.optim.Adam

注意, 1、可以使用 model.parameters() 来获取参数;获取模型中 所有 requires_grad = True 的参数;
2、优化类的使用发发: 1)实例化; 2)所有参数的梯度,将其值置为0; 3)反向传播计算梯度; 4)更新参数值


训练数据

for i in range(10000):
    # 传入数据,计算结果
    y_predict = model(x)
    # y_predict = model(x_true) # 向前计算预测值
    loss = criterion(y_true, y_predict)  # 调用损失函数,传入真实值和预测值,得到损失结果
    optimizer.zero_grad() # 当前循环参数梯度置为 0 
    loss.backward() # 计算梯度
    optimizer.step() # 更新参数的值
    
    if i%1000 == 0:
        print(f'-- {i} {loss.data}') 
    

-- 0 5.0724568367004395
-- 1000 0.2791365087032318
-- 2000 0.19754908978939056
-- 3000 0.1544567495584488
-- 4000 0.12085574865341187
-- 5000 0.0945650264620781
-- 6000 0.07399351894855499
-- 7000 0.05789710581302643
-- 8000 0.04530230164527893
-- 9000 0.035447314381599426

评估模型

model.eval()

predict = model(x)  # 设置模型为评估模式,即预测模式
predict = predict.data.numpy()   


plt.scatter(x.data.numpy(), y_true.data.numpy(), c='r')
plt.plot(x.data.numpy(), predict) 
plt.show()

在这里插入图片描述


伊织 2022-12-10

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/77832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MyBatis面试题(2022最新版)

整理好的MyBatis面试题库&#xff0c;史上最全的MyBatis面试题&#xff0c;MyBatis面试宝典&#xff0c;特此分享给大家 MyBatis简介 MyBatis是什么&#xff1f; MyBatis 是一款优秀的持久层框架&#xff0c;一个半 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;它…

Kotlin 开发Android app(二十一):协程launch

什么是协程&#xff0c;这可是这几年才有的概念&#xff0c;我们也不用管它是什么概念&#xff0c;先看看他能做什么。 创建协程 添加依赖&#xff1a; implementation org.jetbrains.kotlinx:kotlinx-coroutines-core:1.3.9implementation org.jetbrains.kotlinx:kotlinx-cor…

DCDC电感下方铜箔如何处理

挖&#xff1a;电感在工作时&#xff0c;其持续变化的电流产生的电磁波会或多或少的泄露出来&#xff0c;电感下方的铜箔受电磁波影响&#xff0c;就会有涡流出现&#xff0c;这个涡流&#xff0c;①可能对线路板上的信号线有干扰&#xff0c;②铜箔内的涡流会产生热量&#xf…

申请阿里云域名SSL证书步骤

1.【点击登录】 阿里云 2.选择 DV单域名证书 3.确定购买&#xff0c;支付。 4.完成后&#xff0c;跳转回控制台。 5.点击 证书申请。 6.填写域名、申请人姓名、手机号、邮箱、所在地 7、选择域名验证方式&#xff0c;官方提供了三种验证方式&#xff0c;根据自身情况选择其中…

【Linux】Linux的常见指令详解(下)

目录 前言 head/tail 命令行管道 date sort cal 搜索指令 find which whereis alias grep zip tar file bc history 热键 前言 之前讲了Linux的常见指令详解&#xff08;上&#xff09;&#xff0c;这次终于把下也补齐了。如果对你有帮助还麻烦给博主一个…

Netty_05_六种序列化方式(JavaIO序列化 XML序列化 Hessian序列化 JSON序列化 Protobuf序列化 AVRO序列化)(实践类)

文章目录一、普通的序列化方式(bean对象有直接的java类)1.1 普通的java io byteArray输入输出流的序列化方式1.2 xml序列化方式&#xff08;xml用来做配置文件&#xff0c;这样序列化出来长度很大&#xff09;1.3 Hessian序列化方式&#xff08;这个Dubbo中使用的序列化方式&am…

flask前后端项目--实例-前端部分:-3-vue基本配置

一、基本配置以及验证 1.基础环境&#xff1a;nodejs的安装配置以及注意事项 https://blog.csdn.net/wtt234/article/details/128131999 2.vue使用vite创建文件包的过程 创建项目 npm init vitelatest 根据提示一步步选择&#xff1a; 选择vue 进入项目目录&#xff0c;安装…

【计算机网络】网络层:IPV6

IPV4耗尽&#xff0c;使用具有更多地址空间的IPV6 IPV6特点&#xff1a; (1)IPV6地址128位&#xff0c;更大地址空间&#xff0c;可以划分位更多的层次 (2)IPV6定义许多拓展首部&#xff0c;可提供更多功能&#xff0c;但IPV6首部长度固定&#xff0c;选项放在有效载荷中 (…

打败阿根廷的究竟是谁

2022年卡塔尔世界杯正在如火如茶的进行着。在今年的世界杯中&#xff0c;有两个令人意外的点&#xff0c;一个是日本队击败的德国队&#xff0c;另外一点是沙特队战胜了实力强盛的阿根廷队。 有人说打败阿根廷队的不是沙特队&#xff0c;而是科技------"半自动越位"技…

某Y易盾滑块acToken、data逆向分析

内容仅供参考学习 欢迎朋友们V一起交流&#xff1a; zcxl7_7 目标 网址&#xff1a;案例地址 这个好像还没改版&#xff0c;我看官网体验那边已经进行了混淆 只研究了加密的生成&#xff0c;环境不正确可能会导致的加密结果对 (太累了&#xff0c;先缓缓吧&#xff0c;最近事比…

创建Mongo官方的免费数据库并使用VSCode连接

注册账号 https://cloud.mongodb.com/ 在这个平台注册账号&#xff0c;并登录 创建数据库 选择shared&#xff0c;其他要收费 导入样本数据 导入后会发现数据中多了很多sample数据&#xff0c;用于练习 创建访问用户 允许任何地址访问 如果需要任何IP地址都能访问&#xff…

Qt-数据库开发-用户登录、后台管理用户(6)

Qt-数据库开发-使用QSqlite数据库实现用户登录、后台管理用户功能 文章目录Qt-数据库开发-使用QSqlite数据库实现用户登录、后台管理用户功能1、概述2、实现效果3、主要代码4、完整源代码更多精彩内容&#x1f449;个人内容分类汇总 &#x1f448;&#x1f449;数据库开发 &…

UG环境设置

UG环境设置UG设置工作路径默认设置方法1&#xff1a; 修改快捷键路径方法2&#xff1a;修改“用户默认设置”UG设置窗口标题效果方法注意设置十字准线效果设置方法角色设置窗口布局效果方法命令搜索UG设置工作路径 默认设置 打开NX软件&#xff0c;新建模型默认路径如下&…

代码随想录算法训练营第三天| 链表理论基础, 203.移除链表元素,707.设计链表,206.反转链表

代码随想录算法训练营第三天| 链表理论基础&#xff0c; 203.移除链表元素&#xff0c;707.设计链表&#xff0c;206.反转链表 链表理论基础 建议&#xff1a;了解一下链接基础&#xff0c;以及链表和数组的区别 文章链接&#xff1a; 203.移除链表元素 建议&#xff1a; 本…

智源社区AI周刊No.109:ChatGPT预示大模型取代搜索引擎;Stable Diffusion2.1发布,8k高清图像生成...

汇聚每周AI热点&#xff0c;不错过重要资讯&#xff01;欢迎扫码&#xff0c;关注并订阅智源社区AI周刊。ChatGPT火出圈&#xff1a;对话大模型驱动新型搜索范式诞生&#xff0c;或将取代搜索引擎火出圈的ChatGPT注册用户数量已超过五百万&#xff0c;无疑是2022年最火的AI模型…

ReactNative MacOS环境初始化项目(安卓)

MacOS 12.6.1 官方文档 英文 https://reactnative.dev/docs/environment-setup中文 https://www.react-native.cn/docs/environment-setup 相关文档 ReactNative MacOS环境初始化项目(ios)OpenJDK 与 AdoptOpenJDK 的区别 安装步骤 安装Homebrew - /bin/zsh -c "$(curl -f…

spring学习记录(七)

Spring中对象分类 Spring是一个功能强大的容器&#xff0c;容器中存储的是一个一个的对象&#xff0c;容器中的对象分为&#xff1a; 简单对象复杂对象 简单对象就是可以通过构造器直接new 出来的对象&#xff1b; 复杂对象是不可以直接通过构造器直接new出来的对象。 无论是…

[附源码]Python计算机毕业设计SSM基于课程群的实验管理平台(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Java中异常处理方式

文章目录针对异常的处理主要有两种方式&#xff1a;1.抛出异常2.try catch 捕获异常三道经典异常处理代码题针对异常的处理主要有两种方式&#xff1a; 1.抛出异常 遇到异常不进行具体处理&#xff0c;而是继续抛给调用者&#xff08;throw&#xff0c;throws&#xff09;抛出…

java8新特性之toMap的用法——全网独一无二的通俗易懂的讲解

对于java8的新特性toMap方法&#xff0c;相信有很多人都在工作中用过&#xff0c;接下来就通俗易懂的讲解一下toMap吧 先来看看官网对于toMap方法的解释 toMap有个三个重载的方法&#xff0c;每一个重载方法的详解分别如下 &#xff08;1&#xff09;方法1&#xff1a;两个参…