Pytorch神经网络模型nn.Sequential与nn.Linear

news2024/11/15 21:24:28

1、定义模型

对于标准深度学习模型,我们可以使用框架的预定义好的层。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。

我们首先定义一个模型变量net,它是一个Sequential类的实例。 Sequential类将多个层串联在一起。 当给定输入数据时,Sequential实例将数据传入到第一层, 然后将第一层的输出作为第二层的输入,以此类推。

在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential。 但是由于以后几乎所有的模型都是多层的,在这里使用Sequential会让你熟悉“标准的流水线”。

回顾 图3.1.2中的单层网络架构, 这一单层被称为全连接层(fully-connected layer), 因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。

在PyTorch中,全连接层在Linear类中定义。 值得注意的是,我们将两个参数传递到nn.Linear中。 第一个指定输入特征形状,即2,第二个指定输出特征形状,输出特征形状为单个标量,因此为1。

import torch

# nn是神经网络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))

2、初始化模型参数

在使用net之前,我们需要初始化模型参数。 如权重和偏置。 深度学习框架通常有预定义的方法来初始化参数。 在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样, 偏置参数将初始化为零。

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

3、定义损失函数

计算均方误差使用的是MSELoss类,也称为平方L2范数。 默认情况下,它返回所有样本损失的平均值。

loss = nn.MSELoss()

4、定义优化算法

小批量随机梯度下降算法是一种优化神经网络的标准工具, PyTorch在optim模块中实现了该算法的许多变种。 当我们实例化一个SGD实例时,我们要指定优化的参数 (可通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。 小批量随机梯度下降只需要设置lr值,这里设置为0.03。

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

5、训练

通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。 我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。 

在每个迭代周期里,我们将完整遍历一次数据集(train_data), 不停地从中获取一个小批量的输入和相应的标签。 对于每一个小批量,我们会进行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)。
  • 通过进行反向传播来计算梯度。
  • 通过调用优化器来更新模型参数

为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X) ,y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

# epoch 1, loss 0.000248
# epoch 2, loss 0.000103
# epoch 3, loss 0.000103

下面比较生成数据集的真实参数通过有限数据训练获得的模型参数。要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

# w的估计误差: tensor([-0.0010, -0.0003])
# b的估计误差: tensor([-0.0003])

batchsize的选择和学习率调整

batchsize的选择和学习率调整_batchsize和学习率-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1408613.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis:定时清理垃圾图片

首先理解清理垃圾文件的原理 在填写表单信息时上传图片后就已经存入云中,但是此时取消表单的填写这个图片就变成垃圾图片,所以在点击新建填写表单的方法/upload中,把上传的图片名字存入Redis的value中,key(key1&#…

深入理解badblocks

文章目录 一、概述二、安装2.1、源码编译安装2.2、命令行安装2.3、安装确认 三、重要参数详解3.1、查询支持的参数3.2、参数说明 四、实例4.1、全面扫描4.2、破坏性写入并修复4.3、非破坏性写入测试 五、实现原理六、注意事项 团队博客: 汽车电子社区 一、概述 badblocks命令是…

NODE笔记 2 使用node操作飞书多维表格

前面简单介绍了node与简单的应用,本文通过结合飞书官方文档 使用node对飞书多维表格进行简单的操作(获取token 查询多维表格recordid,删除多行数据,新增数据) 文章目录 前言 前两篇文章对node做了简单的介绍&#xff…

【若依】关于对象查询list返回,进行业务处理以后的分页问题

1、查询对象Jglkq返回 list,对 list 进行业务处理后返回,但分页出现问题。 /*** 嫁功率考勤查询*/RequiresPermissions("hr:kq:list")PostMapping("/list")ResponseBodypublic TableDataInfo list(Jglkq jglkq) throws ParseExcepti…

骨传导耳机综评:透视南卡、韶音和墨觉三大品牌的性能与特点

在当前的蓝牙音频设备领域中,骨传导蓝牙运动耳机以其出色的安全特性和舒适的体验,受到了健身爱好者们的广泛好评。这类耳机不同于我们常见的入耳式耳机,它的工作方式是直接通过振动将声音传递到用户的耳骨中,这样既可以享受音乐&a…

鸿蒙系统的APP开发

鸿蒙系统(HarmonyOS)是由华为公司开发的一款分布式操作系统。它被设计用于在各种设备上实现无缝的、统一的用户体验,包括智能手机、平板电脑、智能电视、智能穿戴等设备。鸿蒙系统的核心理念是支持多终端协同工作,使应用能够更灵活…

unity学习笔记----游戏练习06

一、豌豆射手的子弹控制 创建脚本单独控制子弹的运动 用transform来控制移动 void Update() { transform.Translate(Vector3.right * speed * Time.deltaTime); } 创建一个控制子弹速度的方法,方便速度的控制 private void SetSpeed(float spee…

DS:顺序表的实现(超详细!!)

创作不易,友友们给个三连呗! 本文为博主在DS学习阶段的第一篇博客,所以会介绍一下数据结构,并在最后学习对顺序表的实现,在友友们学习数据结构之前,一定要对三个部分的知识——指针、结构体、动态内存管理的…

springboot118共享汽车管理系统

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的共享汽车管理系统 适用于计算机类毕业设计,课程设计参考与学习用途。仅供学习参考, 不得用于商业或者非法用途,否则,一切后果请用户自负。 看运行截图看 第五章 第四章 获…

Java 集合List相关面试题

📕作者简介: 过去日记,致力于Java、GoLang,Rust等多种编程语言,热爱技术,喜欢游戏的博主。 📗本文收录于java面试题系列,大家有兴趣的可以看一看 📘相关专栏Rust初阶教程、go语言基…

图形用户界面(GUI)开发教程

文章目录 写在前面MATLAB GUI启动方式按钮(Push Button)查看属性tag的命名方式回调函数小小的总结 下拉菜单(Pop-up Menu)单选框(Radio Button)和复选框(Check Box)静态文本&#xf…

将vue组件发布成npm包

文章目录 前言一、环境准备1.首先最基本的需要安装nodejs,版本推荐 v10 以上,因为需要安装vue-cli2.安装vue-cli 二、初始化项目1.构建项目2.开发组件/加入组件3. 修改配置文件 三、调试1、执行打包命令2、发布本地连接包3、测试项目 四、发布使用1、注册…

开源客户沟通平台Chatwoot账号激活问题

安装docker docker-compose 安装git clone https://github.com/chatwoot/chatwoot 下载之后根目录有一个docker-compose.production.yaml将其复制到一个目录 重命名 docker-compose.yaml 执行docker-compose up -d 构建 构建之后所有容器都安装好了 直接访问http://ip:3…

基于 Docker 部署 Pingvin Share 文件共享平台

一、Pingvin Share 介绍 Pingvin Share 简介 Pingvin Share 是自托管文件共享平台,是 WeTransfer 的替代方案。 Pingvin Share 特点 在 2 分钟内启动您的实例使用可通过链接访问的文件创建共享没有文件大小限制,只有你的磁盘是你的限制设置共享到期时间…

Mysql全局优化

Mysql全局优化总结 从上图可以看出SQL及索引的优化效果是最好的,而且成本最低,所以工作中我们要在这块花更多时间。 补充一点配置文件my.ini或my.cnf的全局参数: 假设服务器配置为: CPU:32核内存:64GDIS…

highcharts.css文件的样式覆盖了options的series里面的color问题解决

文章目录 一、问题背景二、解决问题 一、问题背景 原本的charts我们的每个数据是有对应的color显示的,如下图: 后面我们系统做了黑白模式,引入了highcharts的css文件,结果highcharts的css文件中class的颜色样式覆盖了我们数据中的…

vue打包后与本地测试样式不同问题,element-ui样式打包部署前后样式不同。

个别文件的样式中<style>未加scope。 查找到一些文件中修改了对应页面的elementUI的样式&#xff0c;但未加scope 给<style>加上scope&#xff0c;就好了。

shell脚本登录dlut-lingshui并设置开机连网和断网重连

本文提供了一个用于无图形界面linux系统自动连接dlut-lingshui校园网的shell脚本&#xff0c;并提供了设置开机联网以及断网重连的详细操作步骤。本文的操作在ubuntu 22.04系统上验证有效&#xff0c;在其他版本的linux系统上操作时遇到问题可以自行百度。 1. 获取校园网认证界…

面试题之RocketMq

1. RocketMq的组成及各自的作用&#xff1f; 在RocketMq中有四个部分组成&#xff0c;分别是Producer&#xff0c;Consumer&#xff0c;Broker&#xff0c;以及NameServer&#xff0c;类比于生活中的邮局&#xff0c;分别是发信者&#xff0c;收信者&#xff0c;负责暂存&#…

C#学习(十)——WPF重构与美化

一、Entity Framework Core 特点&#xff1a;【跨平台】&#xff0c;【建模】&#xff0c;【查询、更改、保存】&#xff0c;【并发】&#xff0c;【事务】&#xff0c;【缓存】&#xff0c;【数据迁移】 EF的组件 二、重构&#xff1a;构建数据模型 项目延续C#学习(九)的 项…