【深度学习入门篇 ②】Pytorch完成线性回归!

news2025/2/28 3:14:00

🍊,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙·终身成长社群创始团队嘉宾橙似锦计划领衔成员阿里云专家博主腾讯云内容共创官CSDN人工智能领域优质创作者 。

易编橙:一个帮助编程小伙伴少走弯路的终身成长社群!


上一部分我们自己通过torch的方法完成反向传播和参数更新,在Pytorch中预设了一些更加灵活简单的对象,让我们来构造模型、定义损失,优化损失等;那么接下来,我们一起来了解一下其中常用的API!

nn.Module

nn.Module 是 PyTorch 框架中用于构建所有神经网络模型的基类。在 PyTorch 中,几乎所有的神经网络模块(如层、卷积层、池化层、全连接层等)都继承自 nn.Module。这个类提供了构建复杂网络所需的基本功能,如参数管理、模块嵌套、模型的前向传播等。

当我们自定义网络的时候,有两个方法需要特别注意:

  1. __init__需要调用super方法,继承父类的属性和方法

  2. farward方法必须实现,用来定义我们的网络的向前计算的过程

用前面的y = wx+b的模型举例如下:

from torch import nn
class Lr(nn.Module):
    def __init__(self):
        super(Lr, self).__init__()  # 继承父类init的参数
        self.linear = nn.Linear(1, 1) 

    def forward(self, x):
        out = self.linear(x)
        return out
  •  nn.Linear为torch预定义好的线性模型,也被称为全链接层,传入的参数为输入的数量,输出的数量(in_features, out_features),是不算(batch_size的列数)
  • nn.Module定义了__call__方法,实现的就是调用forward方法,即Lr的实例,能够直接被传入参数调用,实际上调用的是forward方法并传入参数
  • __init__方法里面的内容就是类创建的时候,跟着自动创建的部分。
  • 与之对应的就是__del__方法,在对象被销毁时执行一些清理操作。
# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)

优化器类

优化器(optimizer),可以理解为torch为我们封装的用来进行更新参数的方法,比如常见的随机梯度下降(stochastic gradient descent,SGD)。

优化器类都是由torch.optim提供的,例如

  1. torch.optim.SGD(参数,学习率)

  2. torch.optim.Adam(参数,学习率)

注意:

  • 参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数


optimizer = optim.SGD(model.parameters(), lr=1e-3) # 实例化
optimizer.zero_grad() # 梯度置为0
loss.backward() #  计算梯度
optimizer.step()  #  更新参数的值

损失函数

  1. 均方误差:nn.MSELoss(),常用于回归问题

  2. 交叉熵损失:nn.CrossEntropyLoss(),常用于分类问题

model = Lr() # 实例化模型
criterion = nn.MSELoss() # 实例化损失函数
optimizer = optim.SGD(model.parameters(), lr=1e-3) #  实例化优化器类
for i in range(100):
    y_predict = model(x_true) # 预测值
    loss = criterion(y_true,y_predict) # 调用损失函数传入真实值和预测值,得到损失
    optimizer.zero_grad() 
    loss.backward() # 计算梯度
    optimizer.step()  # 更新参数的值

线性回归代码!

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt


x = torch.rand([50,1])
y = x*3 + 0.8

# 自定义线性回归模型
class Lr(nn.Module):
    def __init__(self):
        super(Lr,self).__init__()
        self.linear = nn.Linear(1,1)
		
    def forward(self, x):    # 模型的传播过程
        out = self.linear(x)
        return out

# 实例化模型,loss,和优化器
model = Lr()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
#训练模型
for i in range(30000):
    out = model(x) 
    loss = criterion(y,out) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step()  
    if (i+1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))

# 模型评估模式,之前说过的
model.eval() 
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

  • 可以看出经过30000次训练后(相当于看书,一遍遍的回归学习),基本就可以拟合预期直线了

GPU上运行代码

当模型太大,或者参数太多的情况下,为了加快训练速度,经常会使用GPU来进行训练

此时我们的代码需要稍作调整:

1.判断GPU是否可用torch.cuda.is_available()

torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device(type='cuda', index=0)  # 使用GPU

2.把模型参数和input数据转化为cuda的支持类型

model.to(device)
x.to(device)

3.在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型

predict = predict.cpu().detach().numpy() 
  • predict.cpu() predict张量从可能的其他设备(如GPU)移动到CPU上
  • predict.detach() .detach()方法会返回一个新的张量,这个张量不再与原始计算图相关联,即它不会参与后续的梯度计算。
  • .numpy()方法将张量转换为NumPy数组。

 GPU代码:

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import time

x = torch.rand([50,1])
y = x*3 + 0.8

class Lr(nn.Module):
    def __init__(self):
        super(Lr,self).__init__()
        self.linear = nn.Linear(1,1)

    def forward(self, x):
        out = self.linear(x)
        return out


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x,y = x.to(device),y.to(device)

model = Lr().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)


for i in range(300):
    out = model(x)
    loss = criterion(y,out)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i+1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))


model.eval() 
predict = model(x)
predict = predict.cpu().detach().numpy() 
plt.scatter(x.cpu().data.numpy(),y.cpu().data.numpy(),c="r")
plt.plot(x.cpu().data.numpy(),predict,)
plt.show()

💯常见的优化算法

在大多数情况下,我们关注的是最小化损失函数,因为它衡量了模型预测与真实标签之间的差异。

梯度下降算法(batch gradient descent BGD)

每次迭代都需要把所有样本都送入,这样的好处是每次迭代都顾及了全部的样本,做的是全局最优化,但是有可能达到局部最优。

随机梯度下降法 (Stochastic gradient descent SGD)

针对梯度下降算法训练速度过慢的缺点,提出了随机梯度下降算法,随机梯度下降算法算法是从样本中随机抽出一组,训练后按梯度更新一次,然后再抽取一组,再更新一次,在样本量及其大的情况下,可能不用训练完所有的样本就可以获得一个损失值在可接受范围之内的模型了。

小批量梯度下降 (Mini-batch gradient descent MBGD)

SGD相对来说要快很多,但是也有存在问题,由于单个样本的训练可能会带来很多噪声,使得SGD并不是每次迭代都向着整体最优化方向,因此在刚开始训练时可能收敛得很快,但是训练一段时间后就会变得很慢。在此基础上又提出了小批量梯度下降法,它是每次从样本中随机抽取一小批进行训练,而不是一组,这样即保证了效果又保证的速度。

AdaGrad

AdaGrad算法就是将每一个参数的每一次迭代的梯度取平方累加后在开方,用全局学习率除以这个数,作为学习率的动态更新,从而达到自适应学习率的效果

Adam

Adam(Adaptive Moment Estimation)算法是将Momentum算法和RMSProp算法结合起来使用的一种算法,能够达到防止梯度的摆幅多大,同时还能够加开收敛速度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1918100.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WEB攻防-通用漏洞SQL注入-ACCESS一般注入与偏移注入

ACCESS数据库 Access数据库是Microsoft Office套件中一款小型关系型数据库管理系统 单个数据库文件:Access数据库通常以一个单独的文件形式存在(如.accdb或旧版本的.mdb文件),这个文件包含了数据库的所有对象,如表、字…

确保智慧校园安全,充分利用操作日志功能

智慧校园基础平台系统的操作日志功能是确保整个平台运行透明、安全及可追溯的核心组件。它自动且详尽地记录下系统内的每一次关键操作细节,涵盖操作的具体时间、执行操作的用户账号、涉及的数据对象(例如学生信息更新、课程调度变动等)、操作…

Windows 桌面改造小技巧 · 一键去除快捷方式小箭头和小盾牌

Windows的桌面上,总会有一些不如意的小地方,比如快捷方式上的小箭头和小盾牌图标: 标志挡住了应用图标,显得很难受 这些角标作用如下: 快捷方式角标是用来提示你这是一个快捷方式的,其实这个角标还好&…

充气膜游泳馆安全吗—轻空间

充气膜游泳馆,作为一种新型的游泳场馆,以其独特的结构和众多优点,逐渐受到各地体育设施建设者的青睐。然而,关于充气膜游泳馆的安全性,一些人仍然心存疑虑。那么,充气膜游泳馆到底安全吗?轻空间…

进程的状态和优先级

一.进程的退出 进程内核PCB数据和代码,它们都要占据内存空间,进程退出的核心工作之一就是释放掉自己的PCB数据和代码。 为什么要创建出进程呢?一定是我要进程完成某些任务! 当进程退出了,不光光只是退出了这么简单&…

巢链接小程序正式上线!全面开启AI共创新时代

巢链接小程序正式上线!全面开启AI共创新时代 🚀 今天我们官宣,巢链接小程序正式上线啦!这是一款致力于打造连接开发者、创业者和运营者的全新平台,旨在通过共创和分享实现共同成长。快来了解巢链接小程序的诸多功能&a…

普中51单片机:定时器与计数器详解及应用(七)

文章目录 引言定时器工作原理TMOD定时器/计数器工作模式寄存器定时器工作模式模式0(13位定时器/计数器)模式1(16位定时器/计数器)模式2(8位自动重装模式)模式3(两个8位计数器) 定时器配置流程代码演示——LED1间隔1秒闪烁代码演示——按键1控制LED流水灯状态代码演示——LCD160…

抖音机构号授权矩阵系统源码:打造自媒体帝国的新利器

在自媒体风起云涌的时代,抖音作为短视频领域的佼佼者,早已成为内容创作者们争相入驻的热门平台。然而,随着竞争加剧,如何在这场流量大战中脱颖而出,成为每一位自媒体人不得不面对的课题。今天,我们将带您深…

JVM:类加载器

文章目录 一、什么是类加载器 一、什么是类加载器 类加载器(ClassLoader)是Java虚拟机提供给应用程序去实现获取类和接口字节码数据的技术。类加载器只参与加载过程总的字节码获取并加载到内存这一部分。

【JavaWeb程序设计】JavaBean(一)

目录 一、<jsp:useBean>、<jsp:setProperty>、<jsp:getProperty>的使用 1. 运行截图 2. UserBean.java 3. login.html 4. display.jsp 二、设计求三角形面积 1. 运行截图 2. 设计View&#xff08;inputTriangle.jsp&#xff09; 3. 设计Model&#…

AI发展到现在,国内大模型行业还有哪些机会?

随着各种AI工具的发展&#xff0c;AI领域的竞争格局正在悄然变化。GPT-4o被许多人误解为重大突破的更新&#xff0c;实际上主要是在多模态功能上的增强&#xff0c;而核心智能水平似乎仍停留在GPT-4阶段。这一现象为我们思考AI发展路径和国内大模型行业的机遇提供了新的视角。 …

运维Tips | Ubuntu 24.04 安装配置 xrdp 远程桌面服务

[ 知识是人生的灯塔,只有不断学习,才能照亮前行的道路 ] Ubuntu 24.04 Desktop 安装配置 xrdp 远程桌面服务 描述:Xrdp是一个微软远程桌面协议(RDP)的开源实现,它允许我们通过图形界面控制远程系统。这里使用RDP而不是VNC作为远程桌面,是因为Windows自带的远程桌面连接软…

软件缺陷简介

缺陷种类 遗漏&#xff0c;指规定或预期的需求为体现在产品种错误&#xff0c;需求是明确的&#xff0c;在实现阶段未将需求的功能正确实现冗余&#xff0c;需求说明文档中未涉及的需求被实现了不满意&#xff0c;用户对产品的实现不满意也成为缺陷 缺陷等级划分 致命&#…

Qt QWebSocket网络编程

学习目标&#xff1a;Qt QWebSocket网络编程 学习前置环境 QT TCP多线程网络通信-CSDN博客 学习内容 WebSocket是一种通过单个TCP连接提供全双工通信信道的网络技术。2011年&#xff0c;IETF将WebSocket协议标准化为 RFC6455&#xff0c;QWebSocket可用于客户端应用程序和服…

金龙鱼:只是躺枪?

中储粮罐车运输油罐混用事件持续发酵&#xff0c;食用油板块集体躺枪。 消费者愤怒的火&#xff0c;怕是会让食用油企们一点就着。 今天&#xff0c;我们聊聊“油”茅——金龙鱼。 一边是业内人士指出&#xff0c;油罐混用的现象普遍存在&#xff0c;另一边是金龙鱼回应称&am…

Mac虚拟机跑Windows流畅吗 Mac虚拟机连不上网络怎么解决 mac虚拟机网速慢怎么解决

随着技术的发展&#xff0c;很多用户希望能在Mac电脑上运行Windows系统&#xff0c;从而能够使用那些仅支持Windows系统的软件。使用虚拟机软件可以轻松满足这一需求。但是&#xff0c;很多人可能会有疑问&#xff1a;“Mac虚拟机跑Windows流畅吗&#xff1f;”&#xff0c;而且…

3SRB5016-ASEMI逆变箱专用3SRB5016

编辑&#xff1a;ll 3SRB5016-ASEMI逆变箱专用3SRB5016 型号&#xff1a;3SRB5016 品牌&#xff1a;ASEMI 封装&#xff1a;SGBJ-5 批号&#xff1a;2024 现货&#xff1a;50000 最大重复峰值反向电压&#xff1a;1600V 最大正向平均整流电流(Vdss)&#xff1a;50A 功…

懂点技术就可以做,适合程序员的一种生意思路|在FlowUs记录成长 发布知识库

你们是否经常在闲暇时刷刷手机&#xff0c;看看视频&#xff0c;打发时间呢&#xff1f;其实&#xff0c;这些零散的时间完全可以被利用起来&#xff0c;成为你们财富增长的源泉。下面我将分享一种适合程序员的生意思路&#xff0c;让你们用技术的力量&#xff0c;将知识转化为…

星环科技推出语料开发工具TCS,重塑语料管理与应用新纪元

5月30-31日&#xff0c;2024向星力未来数据技术峰会期间&#xff0c;星环科技推出一款创新的语料开发工具——星环语料开发工具TCS&#xff08;Transwarp Corpus Studio&#xff09;&#xff0c;旨在通过全面的语料生命周期管理&#xff0c;极大提升语料开发效率&#xff0c;助…

什么是业务架构、数据架构、应用架构和技术架构

TOGAF(The Open Group Architecture Framework)是一个广泛应用的企业架构框架&#xff0c;旨在帮助组织高效地进行架构设计和管理。而TOGAF的核心就是由我们熟知的四大架构领域组成&#xff1a;业务架构、数据架构、应用架构和技术架构。 所以今天我们就来聊聊&#xff0c;企业…