深度学习(15)--PyTorch构建卷积神经网络

news2025/1/16 15:55:36

目录

一.PyTorch构建卷积神经网络(CNN)详细流程

二.graphviz + torchviz使PyTorch网络可视化

2.1.可视化经典网络vgg16

2.2.可视化自己定义的网络


一.PyTorch构建卷积神经网络(CNN)详细流程

卷积神经网络(Convolutional Neural Networks)是一种深度学习模型或类似于人工神经网络的多层感知器,常用来分析视觉图像。

卷积神经网络的详细介绍可以参考博主写的文章:

深度学习(2)--卷积神经网络(CNN)-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135668789?spm=1001.2014.3001.5501

PyTorch构建神经网络的第一步均为引入神经网络包

import torch.nn as nn

卷积神经网络的构建: 

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 卷积层->激活函数->池化层
        self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)  pytorch中是channel_first的,颜色通道写在第一个位置
            nn.Conv2d(                      # 1d对结构化数据 2d对图像数据 3d对视频数据
                in_channels=1,              # 灰度图   输入的特征图数
                out_channels=16,            # 要得到几多少个特征图  输出的特征图数,也就是卷积核的个数(一个卷积核进行卷积可以得到一个特征图,所以卷积核的个数与特征图的数量相同)
                kernel_size=5,              # 卷积核大小
                stride=1,                   # 步长
                padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1
            ),                              # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)  池化后特征数变少
        )
        self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)  
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(2),                # 输出 (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果  最终数据的大小以及分类的数量

    def forward(self, x):
        # 调用卷积层
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7),分类无法对三维的数据进行处理,所以需要将三维图像拉长成一维数据再来进行分类.
        # -1是自动计算,只需给出一个维度的大小,会自动计算另外个维度.eg.5x4 -> x.view(2,-1),-1对应的就是10. 2x5x10 -> x.view(2,-1),-1对应的就是5x10
        # 在此处,给出的第一个参数x.size(0)的值为batch,所以-1对应的值就是32x7x7
        # 调用全连接层(全连接层的输入必须是二维的矩阵,上述的flattern操作将参数x变成了一个二维矩阵)
        output = self.out(x)
        return output

详解:

1.创建的神经网络构建类一定要继承nn.Module,后续要调用Module包里面的方法构建神经网络。

2.构造函数的第一步永远是调用父类的构造函数,利用super()进行调用:

super(CNN, self).__init__()

3.卷积神经网络的层次顺序一般为:卷积层-> 激活函数做非线性变换 ->池化层,并在输出之前设置一层全连接层。

4.上述代码构建的卷积神经网络是顺序Sequential的,设置有两个卷积层,两个激活函数,两个池化层,以及输出前的一个全连接层。(一般卷积一次就要池化一次)

nn.Sequential()

5.卷积层的构造:通过Module模块中的Conv2d来构造卷积层,其中参数分别为:输入图片数据的颜色通道数(第一个卷积层)/输入的特征图数(之后的卷积层)、输出的特征图数、卷积核的大小、步长、padding值。(其中Conv1d用来处理结构化数据,Conv2d用来处理图片数据,Conv3d用来处理视频数据)

此处设置的卷积层由输入的1个特征图数得到最后的32个特征图数

nn.Conv2d(1, 32, 5, 1, 2)
nn.Conv2d(16, 32, 5, 1, 2)

值得注意的是,如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1。

6.此处激活函数设置的是ReLU,可以根据自己的需求设置不同的激活函数。

nn.ReLU()

7.池化层的构造: 只需要设置一个参数,即为进行池化操作的区域大小。

nn.MaxPool2d(kernel_size=2)

8.全连接层的构造:输入的数据最后经过全连接层得到输出数据,参数分别为输入数据的大小,以及最后进行分类的类别数。

self.out = nn.Linear(32 * 7 * 7, 10)

9.前向传播:PyTorch构建的神经网络,前向传播需要手动设置,此处先调用conv1和conv2两层,再将数据拉成二维的传入全连接层,得到最后的输出值。

二.graphviz + torchviz使PyTorch网络可视化

事先需要先安装graphviz库和torchviz库,graphviz具体安装步骤可以参考博主写的文章:

深度学习(9)--pydot库和graphviz库安装流程详解_pydot 怎么安装-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135929146?spm=1001.2014.3001.5501torchviz库可以直接再编译器中进行安装,也可也在cmd中对应环境中使用pip指令安装:

上述两个库安装完之后,导入网络可视化需要用到的头文件:

from torchviz import make_dot
from torchvision.models import vgg16  # 导入vgg16模型用于演示

2.1.可视化经典网络vgg16

# 随机生成一个tensor张量(对应的数据为图片有十张,图片的大小为3x32x32)
x = torch.randn(10, 3, 32, 32)
# 实例化 vgg16
model = vgg16()
# 将 x 输入网络
vgg16_out = model(x)
# 实例化 make_dot
vgg16_result = make_dot(vgg16_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
vgg16_result.render(filename='vgg16_net_Structure', view=False, format='pdf')

生成如下两个文件 

 

2.2.可视化自己定义的网络

# 随机生成一个tensor张量(对应的数据为图片有四张,图片的大小为1x28x28)
x = torch.randn(4, 1, 28, 28)
# 实例化 vgg16
model = CNN()
# 将 x 输入网络
CNN_out = model(x)
# 实例化 make_dot
CNN_result = make_dot(CNN_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
CNN_result.render(filename='CNN_net_Structure', view=False, format='pdf')

生成如下两个文件  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1439383.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日一题 力扣993.二叉树的堂兄弟节点

993. 二叉树的堂兄弟节点 题目描述: 在二叉树中,根节点位于深度 0 处,每个深度为 k 的节点的子节点位于深度 k1 处。 如果二叉树的两个节点深度相同,但 父节点不同 ,则它们是一对堂兄弟节点。 我们给出了具有唯一值…

实现RBAC

一、菜单权限的控制 左边侧边栏的菜单显示是来自于userlnfoStore.menuRoutes,当前进行打印,得到的内容其实就是staticRoutes静态路由表定义的路由数组对象;最终staticRoutes(静态路由)、allAsyncRoutes(动态路由)、anyRoute(任意路由)需要进行一次整合、…

spring boot和spring cloud项目中配置文件application和bootstrap中的值与对应的配置类绑定处理

在前面的文章基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136065211 加载完文件转换为 Environment 中对应的值之后,接下来需要将对应的值与对应的配置类进行绑定,方便对应的组件取值处理接下来的操作。 对应的配置值与配置类绑定通过 Con…

单片机的认识

单片机的定义 先简单理解为: 在一片集成电路芯片上集成了微处理器(CPU )存储器(ROM和RAM)、I/O 接口电路,构成单芯片微型计算机,即为单片机。 把组成微型计算机的控制器、运算器、存储器、输…

MQTT 服务器(emqx)搭建及使用

推荐阅读: MQTT 服务器(emqx)搭建及使用 - 哔哩哔哩 (bilibili.com) 一、EMQX 服务器搭建 1、下载EMQX https://www.emqx.com/zh/try?productbroker 官方中文手册: EMQX Docs 2、安装使用 1、该软件为绿色免安装版本,解压缩后即安装完…

重写Sylar基于协程的服务器(7、TcpServer HttpServer的设计与实现)

重写Sylar基于协程的服务器(7、TcpServer & HttpServer的设计与实现) 重写Sylar基于协程的服务器系列: 重写Sylar基于协程的服务器(0、搭建开发环境以及项目框架 || 下载编译简化版Sylar) 重写Sylar基于协程的服务…

高考志愿填报模拟系统的功能和技术总结

一、金秋志愿高考志愿填报系统主要功能: 用户注册与登录:允许学生和家长注册账号,使用注册的账号登录系统。 个人信息管理:允许用户查看、修改个人信息,如姓名、性别、联系方式等。 高考成绩输入:学生输…

删除和清空Hive外部表数据

外部表和内部表区别 未被external修饰的是内部表(managed table),被external修饰的为外部表(external table); 区别: 内部表数据由Hive自身管理,外部表数据由HDFS管理; …

STM32——LCD(1)认识

目录 一、初识LCD 1. LCD介绍 2. 显示器的分类 3. 像素 4. LED和OLED显示器 5. 显示器的基本参数 (1)像素 (2)分辨率 (3)色彩深度 (4)显示器尺寸 (5&#xff…

Redis篇之缓存雪崩

一、什么的缓存雪崩 缓存雪崩:在同一时间段大量的缓存key同时失效或者redis服务宕机,导致大量请求到达数据库给数据库带来巨大压力,可能导致数据库崩了。 二、应该怎么解决 1.给不同的Key的TTL添加随机值 2.利用Redis集群提高服务的可用性 3…

夜天之书 #95 GreptimeDB 社群观察报告

GreptimeDB 是格睿科技(Greptime)公司研发的一款开源时序数据库,其源代码[1]在 GitHub 平台公开发布。 https://github.com/GreptimeTeam/greptimedb 我从 2022 年开始知道有 GreptimeDB 这个项目。2023 年,我注意到他们的 Commun…

IntelliJ IDEA 2023.3发布,AI 助手出世,新特性杀麻了!!

目录 关键亮点 对 Java 21 功能的完全支持 调试器中的 Run to Cursor(运行到光标)嵌入选项 带有编辑操作的浮动工具栏 用户体验优化 Default(默认)工具窗口布局选项 默认颜色编码编辑器标签页 适用于 macOS 的新产品图标 Speed Sear…

【buuctf--来首歌吧】

用 Audacity 打开,左声道部分可以放大,可以按照长短转换成摩斯密码,放大后: ..... -... -.-. ----. ..--- ..... -.... ....- ----. -.-. -... ----- .---- ---.. ---.. ..-. ..... ..--- . -.... .---- --... -.. --... ----- -…

async 与 await(JavaScript)

目录捏 前言一、async二、await三、使用方法总结 前言 async / await 是 ES2017(ES8) 提出的基于 Promise 解决异步的最终方案。上一篇文章介绍了 回调地狱 与 Promise(JavaScript),因为 Promise 的编程模型依然充斥着大量的 then 方法&#…

canvas绘制横竖坐标轴(带有箭头和刻度)

查看专栏目录 canvas实例应用100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

Redis的数据类型Hash使用场景实战

Redis的数据类型Hash使用场景 常见面试题:redis在你们项目中是怎么用的,除了String数据类型还使用什么数据类型? 怎么保证缓存和数据一致性等问题… Hash模型使用场景 知识回顾: redisTemplate.opsForHash() 方法是 Redis 的 …

基于全连接神经网络模型的手写数字识别

基于全连接神经网络模型的手写数字识别 一. 前言二. 设计目的及任务描述2.1 设计目的2.2 设计任务 三. 神经网络模型3.1 全连接神经网络模型方案3.2 全连接神经网络模型训练过程3.3 全连接神经网络模型测试 四. 程序设计 一. 前言 手写数字识别要求利用MNIST数据集里的70000张…

Multisim14.0仿真(五十五)汽车转向灯设计

一、功能描述: 左转向:左侧指示灯循环依次闪亮; 右转向:右侧指示灯循环依次闪亮; 刹车: 所有灯常亮; 正常: 所有灯熄灭。 二、主要芯片: 74LS161D 74LS04D 74…

深入理解Spark BlockManager:定义、原理与实践

深入理解Spark BlockManager:定义、原理与实践 1.定义 Spark是一个开源的大数据处理框架,其主要特点是高性能、易用性以及可扩展性。在Spark中,BlockManager是其核心组件之一,它负责管理内存和磁盘上的数据块,并确保…

通过docker-compose部署NGINX服务,并使该服务开机自启

要在通过docker-compose部署的NGINX服务实现开机自启,你需要确保Docker守护进程在系统启动时自动运行,并配置docker-compose.yml文件以在容器中运行NGINX服务。以下是步骤: 确保Docker守护进程开机启动: 在Ubuntu/Debian上&#x…