pytorch入门7--自动求导和神经网络

news2025/1/10 9:36:19

深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。
下面是课程笔记。
一、自动求导
举例说明自动求导。
torch中的张量有两个重要属性:data(值)和grad(梯度),当我们在定义一个张量时设requires_grad=True就是说明后续可以使用自动求导机制。
在这里插入图片描述
注意:pytorch里可以设置为自动求导的张量的元素需要是浮点型。
例如,对于e=(a + b) * (b + 1),可以用一个图表示如下:
在这里插入图片描述
我们定义张量时通常是从下往上定义,即先定义张量a,b,再定义张量e(由张量a,b的关系式组成),这样张量e的值就由a,b得到,这就是前向传播(前馈),通常定义为forward函数:
在这里插入图片描述
当我们要进行求导时,求:
在这里插入图片描述
在这里插入图片描述
可以看出,求导是从上到下的,逐级相乘再将路径相加。比如求e对b的偏导数,,从b到e有两条路径,每条路径从e开始逐级求导,结果相乘再将多条路径求导结果相加,这个过程加反向传递(反馈),通过pytorch封装好的backward函数实现。
下面的图比我手绘的应该清楚一些:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现如下(以线性逻辑回归为例,y = w * x,给定训练数据集x,y,求最佳参数w拟合x与y的关系函数):
我们直到在深度学习中,我们都是将损失函数对参数求导,使用梯度下降法等方法使得损失函数最好,从而找到参数的最佳值。

# 训练数据集(人眼可以一下看出y=2*x是最好的拟合,但机器不知道,要一直训练
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0] 

# 参数w,初始值设为1.0
w = torch.tensor([1.0],requires_grad=True)

# 前向传播
def forward(x):
    return x * w

# 损失函数
def loss(x,y):
    y_pred = forward(x) # y_pred即为常说的y_hat,是y在当前w的值下计算的估计值,这里即建立了y_pred与w的关系,可以自动求导
    return (y_pred - y) ** 2

# 训练数据集(梯度下降法)
for epoch in range(1000):
    for x,y in zip(x_data,y_data):
        l = loss(x,y)
        l.backward() # 自动求导,l对w求导,反向传播
        print('\tgrad:',x,y,w,w.grad.item()) # item()用于只含一个元素的tensor中提取值
        w.data = w.data - 0.01 * w.grad.data # **这里使用data属性就是为了防止使用自动求导机制**
        w.grad.data.zero_() # 将上一轮的梯度值清除
    print("progress:",epoch,l.item())

# 测试结果
print("predict(after training)",4,forward(4).item()) # 计算当x=4时,根据训练出的模型求y的估计值
    

在这里插入图片描述
在这里插入图片描述
可以看出w的值一直在增加,直到加到2可以完全拟合训练集中x与y的关系,最后当x等于4时, 估计值接近8.

二、神经网络
在这里插入图片描述
我们知道,神经网络由多层组成,包括输入层、隐含层和输出层,每一层的包含不同个数的结点,每层的结点其实就是当前我们获得的数据的特征值(features),例如输入层(x1,x2,x3,x4,x5)有五个结点分别表示五个特征值,第一层的隐藏层有六个结点,这是就需要一个6 * 5的矩阵w将x的5个特征值转变为6个特征值。当然也可以添加偏置值b如下图所示:
在这里插入图片描述
在这里插入图片描述
而这个矩阵w就是我们要训练出包含着某种关系的参数矩阵,再一层一层的变换,每层都有一个参数矩阵,最终到达输出矩阵的四个特征,即(y1,y2,y3,y4)。
为了是我们的神经网络模型更好地拟合非线性函数关系,还可以使用激活函数:
在这里插入图片描述
激活函数前面的文章讲过,这里不再说了。使用sigmoid激活函数如下:
在这里插入图片描述
代码实现:
1.单隐藏层的神经网络模型
pytorch对于神经网络的代码封装得很好。

# 训练数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])

# 定义一个单隐藏层得神经网络
class LinearModel(torch.nn.Module):# 神经网络的类必须继承类Module
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1,1) # torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1。
        # torch.nn.Linear()第三个参数是bias,设置为True即含有偏置值b,为False不适用偏置值,默认值为True。
        
    def forward(self,x):
        y_pred = self.linear(x) # 使用封装好的linear()计算y的预测值
        return y_pred

# 生成神经网络的模型
model = LinearModel()

# 损失函数
criterion = torch.nn.MSELoss(size_average=False) # size_average=False表示损失函数不求平均值

# 优化器(梯度下降)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) # model.parameters()可以获取神经网络中的所有参数参数矩阵的值,对其进行优化
# lr表示步长
        
# 训练数据集
for epoch in range(100):
    y_pred = model(x_data) # 将数据传入搭建好的神经网络模型得到估计值
    loss = criterion(y_pred,y_data) # 计算损失值
    print(epoch,loss)
    optimizer.zero_grad() # 清除上次的梯度值
    loss.backward() # 自动求导
    optimizer.step() # 优化参数
    
# 输出结果
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

# 测试模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述
在这里插入图片描述
说明:这里的torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1,其它情况使用情况如下:
在这里插入图片描述

可以看出,训练100轮效果不佳,可以训练1000次看看不同结果。

2.多隐藏层的神经网络模型
与单隐藏层神经网络模型区别如下:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = torch.nn.Linear(8,6) # 模型从8维变为6维,再从6维变为4维,再从4维变为1维
        self.linear2 = torch.nn.Linear(6,4)
        self.linear1 = torch.nn.Linear(4,1)
        self.sigmoid = torch.nn.Sigmoid() # 使用sigmoid激活函数
        
    def forward(self,x):
        pred1 = self.sigmoid(self.linear1(x)) # 上一层输出结果传给下一层
        pred2 = self.sigmoid(self.linear2(pred1))
        y_pred = self.sigmoid(self.linear3(pred2))
        return n
    
model = Model()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/385153.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python 数据库连接 + 创建库表+ 插入【内含代码实例】

人生苦短 我用python Python其他实用资料:点击此处跳转文末名片获取 数据库连接 连接数据库前,请先确认以下事项: 您已经创建了数据库 TESTDB.在TESTDB数据库中您已经创建了表 EMPLOYEEEMPLOYEE表字段为 FIRST_NAME, LAST_NAME, AGE, SEX 和 INCOME。连…

前端css整理

如何水平垂直居中一个盒子? 1.已知高度:子盒子设置 display: inline-block; 父盒子设置 line-height 等于高度实现垂直居中;使用 text-align:center实现水平居中 2.父盒子 display:flex; align-items:center;justify-content:center; 3.定位&…

自动驾驶决策规划-控制方向学习资料总结(附相关资料的链接)

项目仓库 欢迎访问我的Github主页 项目名称说明chhCpp学习C仓库chhRobotics学习自动驾驶、控制理论相关仓库(python实现)chhRobotics_CPP学习自动驾驶、控制理论相关仓库(c实现)chhML 、chh-MachineLearning学习机器学习仓库chhRL学习强化学习仓库chhTricks存放一些有意思的t…

SpringSecurity的初次邂逅

【第一篇】SpringSecurity的初次邂逅 1.Spring Security概念 Spring Security是Spring采用 AOP思想,基于 servlet过滤器实现的安全框架。它提供了完善的认证机制和方法级的授权功能。是一款非常优秀的权限管理框架。 Spring Security是一个功能强大且高度可定制的身…

vue权限控制和动态路由

思路 登录:当用户填写完账号和密码后向服务端验证是否正确,验证通过之后,服务端会返回一个token,拿到token之后(我会将这个token存贮到localStore中,保证刷新页面后能记住用户登录状态)&#xf…

颠覆你的认知,业务同事都能开发软件,我简直无地自容……

经常看到网络鼓吹业务人员也能搭建应用,本是嗤之以鼻、半信半疑,但当这件事真实发生在自己身上时,竟觉得此言不虚? 一、背景 最近公司为了集成系统、提升扩展能力,引进了低代码平台JNPF,说个题外话&#…

终于,OpenAI开放ChatGPT API,成本直降90%,百万token才2美元

现在,第三方可以通过 API 将对话模型 ChatGPT 和语音转文本模型 Whisper 集成到自己的应用程序和服务中了。 来源丨机器之心 2022 年 11 月,OpenAI 上线 ChatGPT,自此以后,这个对话模型一路开挂。毫不夸张的说,与 Ch…

4道数学题,求解极狐GitLab CI 流水线|第4题:合并列车

本文来自: 武让 极狐GitLab 高级解决方案架构师 💡 极狐GitLab CI 依靠其一体化、轻量化、声明式、开箱即用的特性,在开发者群体中的使用率越来越高,在国内企业中仅次于 Jenkins ,排在第二位。 极狐GitLab 流水线有 4…

NFT Insider #87:The Sandbox 收购游戏开发工作室 Sviper,GHST 大迁徙即将拉开帷幕

引言:NFT Insider由NFT收藏组织WHALE Members(https://twitter.com/WHALEMembers)、BeepCrypto(https://twitter.com/beep_crypto)联合出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周…

洛必达求极限法则的通俗理解

洛必达求极限法则的通俗理解 洛必达法则是用于计算函数在某一点的极限的方法 它的基本思想是利用函数在该点的导数来逼近极限值。 洛必达法则成立的主要原因是因为它是利用函数的导数来逼近函数值的方法。当函数在某一点处存在导数时,函数的变化趋势可以由导数来…

24小时稳定性爆肝测试!国内外5款远程控制软件大盘点

本文目录前言一、ToDesk远程控制二、向日葵远程控制三、RayLink四、TeamViewer五、AnyDesk总结前言 不论你的职业是什么,从事互联网工作基本就离不开远程,从远程安装系统到远程搞设计,再到做服务器的调控,都需要靠远程来协助完成…

如何实现《电子签名法》要求的可靠电子签名?

电子文档的电子签名怎么弄?我们在工作中经常需要在一些Word、pdf等电子版文件中插入签名,而很多人可能不知道,电子签名怎么弄?怎么做电子签名才有效?电子印章或签名图片属于电子签名吗?当工作或商务交易中&…

Typroa安装教程

Markdown 是一种轻量级标记语言,创始人为约翰格鲁伯(John Gruber)。 它允许人们使用易读易写的纯文本格式编写文档,然后转换成有效的 XHTML(或者HTML)文档。这种语言吸收了很多在电子邮件中已有的纯文本标记…

中小型企业综合组网及安全配置(附拓扑图和具体实现的代码)

目录 一、实验目的 二、设备与环境 三、实验内容及要求 四、实验命令及结果 五、实验总结 六、实验报告和拓扑图下载链接 一、实验目的 1.了解企业网络建设流程 2.掌握组建中小企业网络的组网技术; 3.掌握组建中小企业网络的安全技术 二、设备与环境 微型…

Nginx国密支持问题记录

文章目录添加国密支持可能出现的问题国密不生效,查看 Nginx 可执行文件路径是否正确证书无法解析Nginx无法启动添加国密支持 NGINX添加国密支持 添加国密支持可以直接按照官网的操作顺序操作即可 参考网址:https://www.gmssl.cn/gmssl/index.jsp 可能出…

【解决】ScrollView 子 Content 在应用 Contentt Size Filter 出现位置自偏移错误问题

开发平台:Unity 2022 开发语言:CSharp 6.0   问题描述 问题表现: Scroll View 出现 Content 的 RectTransform 偏移值会出现自变化情况,但此变化情况不符合预期表现。 问题背景: Scroll View 添加 四周型 适配与 P…

Word控件Spire.Doc 【书签】教程(3): 使用 HTML 代码编辑/替换 Word 书签的内容

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下,轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具,专注于创建、编辑、转…

Python爬虫-阿里翻译_csrf

前言 本文是该专栏的第37篇,后面会持续分享python爬虫干货知识,记得关注。 笔者在前面有介绍过百度翻译的案例,感兴趣的同学,可往前翻阅查看(JS逆向-百度翻译sign)。而本文,笔者要介绍的是阿里翻译,相对于百度翻译的参数被逆向需要花点时间,阿里相对于易上手。 下面…

【java】java消息推送至微信公众号详细教程

文章目录读前必看测试号推送谁说程序员不懂浪漫? 将的关心 推送至微信公众号 给女朋友及时的关怀~(这位同学 你女朋友呢?) 读前必看 关于微信开发平台,小程序和公众号是不一样的,而公众号又会区分订阅号、服务号、测…