动手学深度学习(三)深度学习计算

news2024/11/24 20:04:51

一、模型构造

1、继承Module类来构造模型来构造模型

class MLP(nn.Module):
    # 声明带有模型参数的层,这里声明了两个全连接层
    def __init__(self, **kwargs):
        # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
        # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256) # 隐藏层
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)  # 输出层
         

    # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

2、Sequential类继承自Block

Sequential类它提供add函数来逐一添加串联的Module子类实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算

net = MySequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net)
net(X)

3、ModuleList

①定义

ModuleList 是 PyTorch 中的一种容器类,位于 torch.nn 模块下,专门用于存储多个子模块(即网络层)

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)

②ModuleList 和 Python 普通列表的区别

  • 注册模块ModuleList 中的所有子模块都会被注册为模型的一部分。PyTorch 会自动识别并将它们的参数纳入模型的训练和保存中。而普通的 Python 列表并不会注册其中的模块。
  • 参数追踪:使用 ModuleList 后,model.parameters() 可以追踪到列表中的所有模块参数。如果使用普通列表,模型中的这些层的参数将不会被自动管理。
(1)ModuleList
class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])
(2)Python列表
class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]

由结果可以看出,使用了nn.ModuleList([nn.Linear(10, 10)]),自动注册了模块并进行参数追踪,而使用列表 [nn.Linear(10, 10)]定义的参数将不会被自动管理。

 

4、ModuleDict类

ModuleDictPyTorch 中 torch.nn 模块下的一个容器类专门用于存储多个子模块,并以字典的形式组织这些子模块。与 Python 的普通字典不同,ModuleDict 中的子模块会被自动注册为模型的一部分,这使得 PyTorch 可以自动追踪、保存和加载这些模块及其参数

net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问

二、模型参数的访问初始化和共享

init模块,它包含了多种模型初始化方法。

1、访问模型参数

net.named_parameters()

net.named_parameters() : PyTorch 中的一个方法,用于返回模型中所有可训练参数的名称和参数本身(权重和偏置)

print(type(net.named_parameters()))
for name, param in net.named_parameters():
    print(name, param.size())

② nn.Parameter

nn.Parameter:用于定义可以被优化(即可以通过梯度下降等算法进行训练)的参数。当你创建一个 nn.Parameter 对象时,它会自动注册到模型的参数列表中,这意味着它将被包含在模型的参数优化过程中。

class MyModel(nn.Module):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.weight1 = nn.Parameter(torch.rand(20, 20))
        self.weight2 = torch.rand(20, 20)
    def forward(self, x):
        pass

初始化权重的梯度是None,训练过程中回代才改变。

③参数的数值和梯度访问

param.data和param.grad访问和修改相关属性。

for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
        print(name, param.grad)

2、初始化模型参数

①使用init中的方法初始化

下面代码分别是正态分布初始化和常数初始化。

init.normal_(param, mean=0, std=0.01)
init.constant_(param, val=0)

②自定义初始化

参数初始化时使用with torch.no_grad()来暂时禁用梯度计算,这对于初始化权重是有用的,因为我们不希望在初始化时计算梯度。

def init_weight_(tensor):
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

3、共享模型参数

当不同层指向的是同一个实例时,它们共享同样的权重。如果你初始化或更新其中一个层的参数,实际上这几个层都会受到映像。

linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear) 
print(net)
for name, param in net.named_parameters():
    init.constant_(param, val=3)
    print(name, param.data)

三、自定义层

1、不含模型参数的自定义层

class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)
    def forward(self, x):
        return x - x.mean()
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

2、含模型参数的自定义层

class MyListDense(nn.Module):
    def __init__(self):
        super(MyListDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net = MyListDense()
print(net)

四、读取和存储

1、读写Tensor

torch.save():将张量存到指定文件中。

torch.load():载入指定文件中的张量。

y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

2、读写模型

state_dict()方法:

  • 保存模型的参数:通过 state_dict(),你可以将模型的参数提取出来并保存为一个字典,以便稍后加载或分享。
  • 加载模型的参数:可以通过 load_state_dict() 方法将保存的参数字典加载到模型中。
  • 检查模型的当前参数状态state_dict() 方便调试时检查模型的权重和偏置。
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
net.state_dict()

PATH = "./net.pt"
torch.save(net.state_dict(), PATH)

net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2129742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用CubeMX复现正点原子TFTLCD驱动例程

来源:正点原子 FMC的工作原理暂时先欠着,先记录一下CRUD的过程。 第一步准备一个us级别延时函数,不会的参考拙作:STM32的定时器简介-CSDN博客 第二部开启FMC外设: ①进入 Pinout->FMC 配置栏,配置 …

【隐私计算】Paillier半同态加密算法

一、何为同态加密(HE)? HE是一种特殊的加密方法,它允许直接对加密数据执行计算,如加法和乘法,而计算过程不会泄露原文的任何信息。计算的结果仍然是加密的,拥有密钥的用户对处理过的密文数据进…

树莓派5开发板-安装Raspberry Pi系统-学习记录1

树莓派5开发板介绍 树莓派5(Raspberry Pi 5)是树莓派系列最新的开发板,相较于前几代产品,它在性能、连接性和功能方面都有了显著提升。以下是树莓派5的一些主要特点: 处理器:树莓派5搭载了Broadcom BCM27…

如何基于gpt模型抢先打造成功的产品

来自:Python大数据分析 费弗里 ChatGPT、gpt3.5以及gpt4,已然成为当下现代社会中几乎人尽皆知的话题,而当此种现象级产品引爆全网,极大程度上吸引大众注意力的同时,有一些嗅觉灵敏的人及时抓住了机会,通过快…

【FreeRL】我的深度学习库构建思想

文章目录 前言参考python环境效果已复现结果 综述DQN.py(主要)算法实现参数修改细节实现显示训练,保存训练 Buffer.pyevaluate.pylearning_curves 前言 代码实现在:https://github.com/wild-firefox/FreeRL 欢迎star 参考 动手学强化学习e…

Coggle数据科学 | 小白学 RAG:Milvus 介绍与使用教程

本文来源公众号“Coggle数据科学”,仅用于学术分享,侵权删,干货满满。 原文链接:小白学 RAG:Milvus 介绍与使用教程 什么是Milvus? Milvus 是一款高性能、高扩展性的开源向量数据库,专为处理…

【阿一网络安全】如何让你的密码更安全?(三) - 散列函数

散列函数 散列函数(Hash Function,又称散列算法、哈希函数),是一种从任何一种数据中创建小的数字指纹的方法。 散列值 散列函数,把任意长的消息明文,压缩成摘要,使得数据量变小,将…

[数据集][目标检测]脊椎检测数据集VOC+YOLO格式1137张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):1137 标注数量(xml文件个数):1137 标注数量(txt文件个数):1137 标注…

SpringBoot2:web开发常用功能实现及原理解析-上传与下载

文章目录 一、上传文件1、前端上传文件给Java接口2、Java接口上传文件给Java接口 二、下载文件1、前端从Java接口下载文件2、Java接口调用Java接口下载文件 一、上传文件 1、前端上传文件给Java接口 Controller接口 此接口支持上传单个文件和多个文件,并保存在本地…

基于小程序的教学辅助微信小程序设计+ssm(lw+演示+源码+运行)

教学辅助微信小程序 摘 要 随着移动应用技术的发展,越来越多的学生借助于移动手机、电脑完成生活中的事务,许多的传统行业也更加重视与互联网的结合,由于学生学习的压力越来越大,教学辅助是一个非常不错的教育平台,对…

人工智能(AI)领域各方向顶会和顶刊

在人工智能(AI)这个快速发展的领域,研究人员和从业者需要紧跟最新的研究动态和技术进展。顶级的会议和期刊是获取最新科研成果和交流思想的重要平台。以下是人工智能领域内不同方向的顶级会议和期刊概览。 顶级会议 人工智能基础与综合 A…

客厅无主灯设计:灯位布局与灯光灯具的和谐搭配

在现代家居设计中,客厅作为家庭活动的中心区域,其照明设计的重要性不言而喻。无主灯设计以其灵活多变、氛围营造独特的优势,逐渐成为客厅照明的热门选择。然而,如何合理规划灯位布局,并科学搭配灯光与灯具,…

20240913 每日AI必读资讯

AMD死战CUDA:我是一家软件公司 - AMD重大改变:重心将从硬件开发转向强调软件开发、API 和 AI 体验。 - 软件工程团队规模扩大了三倍,并且全力以赴投入软件开发 - AMD将自家已有5年历史的图形架构RDNA、计算架构CDNA重新整合在一起&#xf…

计算机毕业设计选题推荐-在线拍卖系统-Java/Python项目实战

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

2024年9月13日 十二生肖 今日运势

小运播报:2024年9月13日,星期五,农历八月十一 (甲辰年癸酉月庚辰日),法定工作日。 红榜生肖:猴、鼠、鸡 需要注意:牛、兔、狗 喜神方位:西北方 财神方位:…

windows10通过coursier安装scala

第一步:安装Java 参考菜鸟教程安装Java: https://www.runoob.com/java/java-environment-setup.html#win-install 第二步:安装coursier 进入https://www.scala-lang.org/download/ 如下图所示: 第三步、确定jdk对应的scala版本…

小琳AI课堂:MASS模型——革新自然语言处理的预训练技术

大家好,这里是小琳AI课堂。今天我们来聊聊一个在自然语言处理(NLP)领域非常热门的话题——MASS模型,全称是Masked Sequence to Sequence Pre-training for Language Generation。这是华为诺亚方舟实验室在2019年提出的一种创新模型…

cpp-httplib的下载和使用

cpp-httplib的下载和使用 1.httplib 简介2. httplib 使用2.1 协议接口2.2 双端接口2.3 实际使用 3. 对Server中的Handler回调函数进行分析4. 最后 1.httplib 简介 cpp-httplib(也称为 httplib)是一个基于 C 的轻量级 HTTP 框架,它提供了简单…

统一建模语言UML之类图(Class Diagram)(表示|关系|举例)

文章目录 1.UML2.Class Diagram2.1 类图的表示2.2 类间的关系2.2.1 关联2.2.2 聚合2.2.3 组合2.2.4 泛化(继承)2.2.5 实现(接口实现)2.2.6 依赖 2.3 类图的作用 参考:Class Diagram | Unified Modeling Language (UML)…

2024/9/12 数学“回头看”之R(a)与R(a※)、分布函数、概率密度的特点

注意!这是充分必要条件。 分布函数性质 概率密度性质: