pytorch学习(三)——模型层

news2025/1/12 18:03:58

文章目录

  • 1. 自定义模型层
  • 2. 使用预训练模型
  • 3. 模型构建风格
    • 3.1 使用 `add_module` 方法
    • 3.2 添加进 `Sequential`
    • 3.3 Sequential作为模型容器
    • 3.4 ModuleList作为模型容器
    • 3.5 ModuleDict作为模型容器

当我们构建了数据管道能够将数据一个batch一个batch的取出来后,下一步就是构建模型了,模型的构建将很大程度的影响学习的效果,pytorch的模型层全部都在 torch.nn 模块下。

1. 自定义模型层

如果需要查看模型层的各个API及每个API的作用,大家可以去官网查看,网址放在这里了:https://pytorch.org/docs/stable/nn.html#shuffle-layers
在这里插入图片描述

如果需要自己定义模型层,那么需要继承 nn.Module 模块,类的初始化方法 __init__ 第一行必须调用父类的方法并定义模型层,必须实现 forward(input) 方法并将各层连接起来才行。

下面展示一个示例:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()#调用父类的初始化方法
        # 3层lstm
        self.lstm = nn.LSTM(input_size = 3,hidden_size = 3,num_layers = 5,batch_first = True)
        self.linear = nn.Linear(3,3)
        self.block = Block()

	# 重写forward方法
    def forward(self,x_input):
    	# 定义网络参数传递过程
        x = self.lstm(x_input)[0][:,-1,:]
        x = self.linear(x)
        y = self.block(x,x_input)
        return y

2. 使用预训练模型

因为迁移学习所带来的的影响,使用预训练的模型往往能够带来更好的效果,在pytorch中,很多预训练模型都集成到了 torchvision 中的 models 模块,可以在官网中查看支持的各个已训练好的模型,官网网址 https://pytorch.org/vision/stable/models.html,调用方法也十分简单,比如调用残差神经网络,只需要使用下面的语句即可

from torchvision import models
model = models.resnet152(pretrained=True)

上面的 pretrained=True 表示将模型下载下来,默认的下载路径为 C:\Users\ASUS\.cache\torch\hub\checkpoints ,下载的模型都存储在这里。
在这里插入图片描述

3. 模型构建风格

pytorch构建模型时有许多中风格可以用来添加网络层,第一种就是上面的那种继承 nn.Module 并且自定义类的风格,下面还有几种风格。

3.1 使用 add_module 方法

使用 add_module 能够王模型中添加模型层,示例如下

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,1))
net.add_module("sigmoid",nn.Sigmoid())

print(net)

3.2 添加进 Sequential

net = nn.Sequential(
    nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
    nn.MaxPool2d(kernel_size = 2,stride = 2),
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
    nn.MaxPool2d(kernel_size = 2,stride = 2),
    nn.Dropout2d(p = 0.1),
    nn.AdaptiveMaxPool2d((1,1)),
    nn.Flatten(),
    nn.Linear(64,32),
    nn.ReLU(),
    nn.Linear(32,1),
    nn.Sigmoid()
)

print(net)

这种方式构建时不能给每个层指定名称。

3.3 Sequential作为模型容器

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1))
        )
        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        x = self.conv(x)
        y = self.dense(x)
        return y 

3.4 ModuleList作为模型容器

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,1),
            nn.Sigmoid()]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
net = Net()
print(net)

3.5 ModuleDict作为模型容器

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.layers_dict = nn.ModuleDict({"conv1":nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),
               "pool": nn.MaxPool2d(kernel_size = 2,stride = 2),
               "conv2":nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
               "dropout": nn.Dropout2d(p = 0.1),
               "adaptive":nn.AdaptiveMaxPool2d((1,1)),
               "flatten": nn.Flatten(),
               "linear1": nn.Linear(64,32),
               "relu":nn.ReLU(),
               "linear2": nn.Linear(32,1),
               "sigmoid": nn.Sigmoid()
              })
    def forward(self,x):
        layers = ["conv1","pool","conv2","pool","dropout","adaptive",
                  "flatten","linear1","relu","linear2","sigmoid"]
        for layer in layers:
            x = self.layers_dict[layer](x)
        return x
net = Net()
print(net)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/189.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序函数处理之保姆级讲解

目录 生命周期函数 生命周期函数的调用过程 页面事件函数 页面路由管理 自定义函数 setData设值函数 生命周期函数 在使用Page()构造器注册页面时,需要使用生命周期函数,包括onLoad()页面加载时生命周…

硬件工程师成长之路(10.1)——芯片选型

系列文章目录 1.元件基础 2.电路设计 3.PCB设计 4.元件焊接 5.板子调试 6.程序设计 7.算法学习 8.编写exe 9.检测标准 10.项目举例 11.职业规划 文章目录前言一、电机驱动类1 、直流电机驱动芯片2、步进电机③、资料前言 送给大学毕业后找不到奋斗方向的你(每周…

【车间调度】基于全球邻域和爬坡来优化模糊柔性作业车间调度问题(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑…

VS Studio 搭建跨平台开发环境

VS Studio 搭建跨平台开发环境 增加VS的工作负载 打开Visual Studio Installer 安装器,点击修改 在这个界面找到Linux开发环境,勾上然后在点击右下角的修改等待安装。我的是因为已经有了所以下面那里显示的是关闭,没有的是显示的修改 等待安…

LabVIEW强制重新安装无法运行或损坏的NI软件

LabVIEW强制重新安装无法运行或损坏的NI软件 可以参考附件的录像说明。LabVIEW强制重新安装无法运行或损坏的NI软件 - 北京瀚文网星科技有限公司 (bjcyck.com) 某些NI软件,工具包或驱动程序已损坏,损坏或无法按预期运行,想尝试重新安装以进…

【ArchSummit】众安金融微服务架构演进实战

前言 📫 作者简介:小明java问道之路,专注于研究 Java/ Liunx内核/ C及汇编/计算机底层原理/源码,就职于大型金融公司后端高级工程师,擅长交易领域的高安全/可用/并发/性能的架构设计与演进、系统优化与稳定性建设。 &a…

网络原理——传输层_UDP

JavaEE传送门JavaEE JavaEE——No.2 套接字编程(TCP) JavaEE——网络原理_应用层 目录传输层UDP传输层 端到端之间的传输, 重点关注的是起点和终点 核心的协议有两个: UDP: 无连接, 不可靠传输,面向数据报, 全双工 TCP: 有链接, 可靠传输, 面向字节流, 全双工 UDP UDP协议…

nginx+tomcat(二)

四层代理: 四层代理: 一般使用七层代理也就是http应用层代理,可以反向代理和负载均衡。但是项目要使用长连接,此时内网服务器肯定不能暴漏,还是需要接入层网关进行转发,一般有使用lvs,lvs专门用作四层代理和负载均衡基…

【C++】模板初阶

文章目录一、泛型编程二、函数模板1、概念与格式2、底层原理3、实例化4、参数的匹配规则三、类模板1、概念与格式2、实例化一、泛型编程 我们通过实现一个通用的交换函数来引入泛型编程: void Swap(int& left, int& right) {int temp left;left right;r…

Linux 命令(147) —— truncate 命令

文章目录1.命令简介2.命令格式3.选项说明4.常用示例参考文献1.命令简介 truncate 将文件的大小缩小或扩展到指定的大小。 如果指定的文件不存在将被创建。 如果文件大于指定的大小,则会丢失额外的数据。如果较短,它将被扩展,扩展的稀疏部分…

【牛客刷题--SQL篇】多表查询组合查询SQL25 查找山东大学或者性别为男生的信息

💖个人主页:与自己作战 💯作者简介:CSDN博客专家、CSDN大数据领域优质创作者、CSDN内容合伙人、阿里云专家博主 💞牛客刷题系列篇:【SQL篇】】【Python篇】【Java篇】 📌推荐刷题网站注册地址&a…

Python数据分析与挖掘————图像的处理

系列文章目录 文章目录系列文章目录前言图片的马赛克一.安装matplotlib,numpy等模块二.马赛克图片一.导入图片二.定位区域三.图片的合成图片拼接图像的灰度化一.max()方法二.min()方法三.平均值法mean()函数四.加权平均值法图片的分割总结源代…

基于tauri+vue3.x多开窗口|Tauri创建多窗体实践

最近一种在捣鼓 Tauri 集成 Vue3 技术开发桌面端应用实践,tauri 实现创建多窗口,窗口之间通讯功能。 开始正文之前,先来了解下 tauri 结合 vue3.js 快速创建项目。 tauri 在 github 上star高达53K,而且呈快速增长趋势。相比elect…

DDoS报告团伙规模

攻击资源活跃度分析 在攻击源活时间的监测中发现,和 2019 年趋势一致,存活时间大于 10 天的攻击资源占比 11%。像这种能够长期被控制的肉鸡大部分都是物联网 设备,物联网设备大都存在设备系统老,人员维 护少,更新慢等…

vue当中的事件处理

1.绑定监听v-on 最简单的一个绑定监听的事件 <body><div id"root"><h1>my name is {{name}}</h1><button v-on:click"showInfo">click me</button></div><script type"text/javascript">Vue.…

HotSpot 虚拟机对象探秘-对象的创建、内存布局、访问定位

目录对象的创建检查类的符号引用&#xff0c;是否执行过类的加载过程分配内存指针碰撞&#xff1a;空闲列表&#xff1a;线程安全的问题&#xff0c;对分配内存空间的动作进行同步处理——TLAB初始化虚拟机对对象进行必要的设置&#xff0c;执行构造方法对象的内存布局对象头包…

Spring、MySQL、日期、BigDecimal、集合、反射、序列化中的坑与使用指南

文章目录MySQL中的坑MySQL断开连接Mysql表字段设置为not null如何解决网络瓶颈核心流程的性能查看Spring中的坑与使用注意springboot的配置文件先后顺序定时任务不进行lombok的不适用场景Spring的Bean默认名称生成规则new出来的对象不被Spring所管理SpringBean相关的注解Spring…

Java 类和对象 详解+通俗易懂

文章目录类和对象1. 面对对象的初步认识1.1 什么是面向过程&#xff1f;什么又是面向对象&#xff1f;1.2 对象、成员变量和成员方法的关系和理解2. 类的定义和使用2.1 简单认识类2.2 类的定义格式2.3 小试身手3. 类的实例化3.1 什么是实例化3.2 类和对象的说明4. this 引用4.1…

k8s上部署seata-server集群并注册到nacos上

部署前准备 第一步&#xff1a; 创建seata-server需要的表,有现成的阿里云RDS&#xff0c;就直接在RDS上创建数据库了&#xff0c;方便后面统一管理。 具体的 SQL 参考script/server/db &#xff0c;这里使用的是 MySQL 的脚本&#xff0c;数据库名称为 seata&#xff0c;还需…

对外 API 接口,请把握这3 条原则,16 个小点

对外API接口设计 安全性 1、创建appid,appkey和appsecret 2、Token&#xff1a;令牌&#xff08;过期失效&#xff09; 3、Post请求 4、客户端IP白名单 &#xff08;可选&#xff09; 5、单个接口针对IP限流&#xff08;令牌桶限流&#xff0c;漏桶限流&#xff0c;计数器…