PyTorch实战4:猴痘病识别

news2024/12/22 22:53:00
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P4周:猴痘病识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

目录

    • 一、搭建CNN网络结构
      • 1、原文网络结构
        • 1.1、网络结构赘述
        • 1.2、前向传播实现
        • 1.3、原型网络结构
      • 2、调整网络结构
        • 2.1、提升精度方案
        • 2.2、网络结构调整
      • 3、结果成效与对比
        • 3.1、原模型结果可视化
        • 3.2、调整后结果可视化
    • 二、保存并加载模型

本次实战主要学习内容:

  • 调整网络结构使测试集accuracy到达88%
  • 训练过程中保存效果最好的模型参数

一、搭建CNN网络结构

  1. 读取并加载本地数据以及数据增强可移步至PyTorch实战3:天气识别进行学习
  2. 卷积层、全连接层、池化层以及批量归一化层的详解可移步至PyTorch实战2:彩色图片识别(CIFAR10)

1、原文网络结构

1.1、网络结构赘述

代码定义了一个名为Network_bn的类,继承自nn.Module。

该类实现了一个卷积神经网络(CNN),包含了多个卷积层和池化层。

在类的初始化函数中,首先调用父类nn.Module的构造函数以进行初始化操作。

然后依次创建了六个神经网络层:

  • 两个卷积层、两个Batch Normalization层、一个最大池化层和一个全连接层。
  • 其中Conv2d函数用于创建二维卷积层,
  • BatchNorm2d函数用于创建二维批量归一化层,
  • MaxPool2d函数用于创建最大池化层,
  • Linear函数用于创建全连接层。

网络结构详解:

  • self.conv1是输入为3通道,输出为12通道,卷积核大小为5x5的卷积层;
  • self.bn1是12通道的Batch Normalization层;
  • self.conv2是输入为12通道,输出为12通道,卷积核大小为5x5的卷积层;
  • self.bn2是12通道的Batch Normalization层;
  • self.pool是2x2最大池化层;
  • self.conv4是输入为12通道,输出为24通道,卷积核大小为5x5的卷积层;
  • self.bn4是24通道的Batch Normalization层;
  • self.conv5是输入为24通道,输出为24通道,卷积核大小为5x5的卷积层;
  • self.bn5是24通道的Batch Normalization层;
  • self.fc1是全连接层,输入大小为24x50x50,输出大小为类别数(len(classNames))。

具体实现了以下结构:

  1. 三层卷积层和两个池化层,用于提取特征
  2. 五个批次归一化(Batch Normalization)层,用于加速训练过程
  3. 一个全连接层,用于输出分类结果

1.2、前向传播实现

接下来是forward函数,该函数定义了数据的前向传递流程。

  • 首先将输入x通过第一个卷积层conv1进行卷积,然后将卷积结果输入到第一个Batch Normalization层bn1中进行归一化处理,并通过激活函数F.relu进行非线性变换。
  • 接着将处理结果输入到第二个卷积层conv2中进行卷积,再将卷积结果输入到第二个Batch Normalization层bn2中进行归一化处理,并通过激活函数F.relu进行非线性变换。
  • 然后通过最大池化层pool进行下采样,缩小特征图的尺寸。然后将处理结果输入到第四个卷积层conv4中进行卷积,再将卷积结果输入到第四个Batch Normalization层bn4中进行归一化处理,并通过激活函数F.relu进行非线性变换。
  • 接着将处理结果输入到第五个卷积层conv5中进行卷积,再将卷积结果输入到第五个Batch Normalization层bn5中进行归一化处理,并通过激活函数F.relu进行非线性变换。
  • 最后通过最大池化层pool进行下采样,缩小特征图的尺寸。然后通过view函数将特征图展成一维向量,输入到全连接层fc1中进行分类。

具体来说,

  • x = F.relu(self.bn1(self.conv1(x)))表示先通过卷积层和Batch Normalization层提取特征,再进行ReLU激活;
  • x = F.relu(self.bn2(self.conv2(x)))同理;
  • x = self.pool(x)表示进行2x2最大池化操作;
  • x = F.relu(self.bn4(self.conv4(x)))同理;
  • x = F.relu(self.bn5(self.conv5(x)))同理;
  • x = self.pool(x)同理;
  • x = x.view(-1, 245050)将特征张量展平为一维向量;
  • x = self.fc1(x)表示通过全连接层得到分类结果。

1.3、原型网络结构

import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        
        # 第1层卷积层:输入有3个channel,输出有12个channel,卷积核大小为5x5,步长为1,填充为0
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        # 第1个Batch Normalization层,有12个channel的输出
        self.bn1 = nn.BatchNorm2d(12)
        
        # 第2层卷积层:输入有12个channel,输出有12个channel,卷积核大小为5x5,步长为1,填充为0
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        # 第2个Batch Normalization层,有12个channel的输出
        self.bn2 = nn.BatchNorm2d(12)
        
        # 第1个池化层,窗口大小为2x2,步长为2
        self.pool = nn.MaxPool2d(2,2)
        
        # 第3层卷积层:输入有12个channel,输出有24个channel,卷积核大小为5x5,步长为1,填充为0
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        # 第3个Batch Normalization层,有24个channel的输出
        self.bn4 = nn.BatchNorm2d(24)
        
        # 第4层卷积层:输入有24个channel,输出有24个channel,卷积核大小为5x5,步长为1,填充为0
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        # 第4个Batch Normalization层,有24个channel的输出
        self.bn5 = nn.BatchNorm2d(24)
        
        # 全连接层:输入是24x50x50(24个24x24的feature map),输出是类别数
        self.fc1 = nn.Linear(24*50*50, len(classeNames))

    def forward(self, x):
        # 第1层卷积层+Batch Normalization层+ReLU激活函数
        x = F.relu(self.bn1(self.conv1(x)))      
        # 第2层卷积层+Batch Normalization层+ReLU激活函数
        x = F.relu(self.bn2(self.conv2(x)))     
        # 第1个池化层
        x = self.pool(x)                        
        # 第3层卷积层+Batch Normalization层+ReLU激活函数
        x = F.relu(self.bn4(self.conv4(x)))     
        # 第4层卷积层+Batch Normalization层+ReLU激活函数
        x = F.relu(self.bn5(self.conv5(x)))  
        # 第2个池化层
        x = self.pool(x)                        
        # 展开成一维张量
        x = x.view(-1, 24*50*50)
        # 全连接层
        x = self.fc1(x)
        
        return x
  • 模型打印:

在这里插入图片描述
其中,ReLU激活函数在卷积层中使用以增强模型的非线性特征。所有层之间都使用了Batch Normalization技术,可以有效地加速训练过程并提高模型的稳定性。

2、调整网络结构

2.1、提升精度方案

可以考虑尝试以下几个调整来提升网络精度:

  1. 增加卷积层的数量和大小:增加卷积层的数量和大小可以提高网络的感受野,从而更好地捕捉图像的特征。可以尝试增加一些卷积层和调整卷积核的大小。

  2. 调整池化层的大小:池化层可以减小数据的维度,但是过大的池化层会导致信息丢失。可以尝试使用更小的池化层或者不使用池化层。

  3. 使用更深的网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,这些网络结构能够更好地解决梯度消失问题,从而使得模型训练更加稳定。

  4. 数据扩充(data augmentation):可以通过对训练数据进行随机裁剪、旋转、翻转等操作来扩充数据集,从而提高模型的泛化能力。

  5. 使用预训练的模型进行微调:可以使用在大规模数据集上预训练的模型,在本任务上进行微调,能够更好地利用已有数据集的信息,提高模型的精度。

  6. 调整超参数:可以尝试调整学习率、批量大小、优化器等超参数来获得更好的结果。

  7. 使用更多的正则化技术:可以尝试使用 Dropout、L2 正则化等技术来减少过拟合。

  8. 增加数据集:可以尝试增加训练数据的数量,或者通过数据增强技术来生成更多的数据样本,以便网络可以更好地学习数据特征。

2.2、网络结构调整

调整后的模型包括若干个卷积层和池化层,并使用批归一化层进行正则化。在输入数据后,模型通过这些卷积和池化操作将其转换为特征向量,并使用全连接层对其进行分类。这个模型与原始的卷积神经网络不同之处在于:

  1. 添加了一个新的卷积层,用于提取更多的特征;
  2. 在模型中添加了一个批归一化层,可以加速训练并提高准确率;
  3. 修改了全连接层的输入大小,以适应新的卷积层。
class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=24, out_channels=48, kernel_size=5, stride=1, padding=0), # 添加了一个新的卷积层
            nn.BatchNorm2d(48),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Linear(48 * 23 * 23, len(classNames)) # 修改了全连接层的输入大小

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 48 * 23 * 23)
        x = self.classifier(x)
        return x
  • 模型打印:

在这里插入图片描述

3、结果成效与对比

3.1、原模型结果可视化

在这里插入图片描述
在这里插入图片描述

3.2、调整后结果可视化

在这里插入图片描述
在这里插入图片描述

二、保存并加载模型

首先,将模型的参数保存到指定路径下的文件中。PATH 变量为保存的参数文件名,model.state_dict() 返回当前模型的参数字典,并使用 torch.save() 将其保存到 PATH 指定的文件中。

  • 定义一个文件路径PATH,用于保存模型参数
  • model.state_dict()函数返回一个字典对象,其中包含了模型中所有的可学习参数和缓存项
  • torch.save()函数将这个字典保存到指定的文件路径中
  • torch.load()函数将保存的参数加载回到模型中
  • model.load_state_dict()函数将加载的参数复制到模型中
# 定义保存参数的文件路径
PATH = './model.pth'

# 保存模型参数到指定文件路径
torch.save(model.state_dict(), PATH)

接着,可以在需要使用这些参数的时候,通过 torch.load() 方法将之前保存的参数加载回来。其中,map_location 参数指定了模型应该加载到哪个设备上(例如CPU或GPU)。

model.load_state_dict(torch.load(PATH, map_location=device))  # 将参数加载到模型实例中

总体而言,这段代码的作用是实现了对训练好的神经网络模型进行持久化存储并在需要的时候重新加载模型参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/493382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4。计算机组成原理(2)存储系统

嵌入式软件开发,非科班专业必须掌握的基本计算机知识 核心知识点:数据表示和运算、存储系统、指令系统、总线系统、中央处理器、输入输出系统 这一部分主要讲解了CPU的组成和扩容、CPU与存储器(主存、辅存、缓存)的连接 一 存储…

C++笔记——第十六篇 异常

目录 1.C语言传统的处理错误的方式 2. C异常概念 3. 异常的使用 3.1 异常的抛出和捕获 在函数调用链中异常栈展开匹配原则 3.2异常安全 4.异常的优缺点 1.C语言传统的处理错误的方式 传统的错误处理机制: 1. 终止程序,如assert,缺陷&a…

飞腾ft2000-麒麟V10-SP1安装Docker、运行gitlab容器

目录 一、安装及配置docker 1、卸载docker相关包及删除相关配置文件 2、安装二进制docker 1.下载软件包 2.解压 3.修改镜像加速地址 4.修改profile文件 5.启动docker 6.docker常用命令 二、安装并启动gitlab镜像 1.安装gitlab镜像 1.查询满足使用需求的gitlab版本 2…

很佩服的一个Google大佬,离职了。。

这两天,科技圈又有一个突发的爆款新闻相信不少同学都已经看到了。 那就是75岁的计算机科学家Geoffrey Hinton从谷歌离职了,从而引起了科技界的广泛关注和讨论。 而Hinton自己也证实了这一消息。 提到Geoffrey Hinton这个名字,对于一些了解过…

使用 Mercury 直接从 Jupyter 构建 Web 程序

动动发财的小手,点个赞吧! 有效的沟通在所有数据驱动的项目中都至关重要。数据专业人员通常需要将他们的发现和见解传达给利益相关者,包括业务领导、技术团队和其他数据科学家。 虽然传达数据见解的传统方法(如 PowerPoint 演示文…

Oracle SQL优化相关数据项

要掌握SQL调优技术,就需要能读懂SQL语句的执行计划,要想读懂SQL语句的执行计划,不仅需要准确理解SQL语句执行计划中各操作及其含义,还需要准确理解SQL语句执行计划中各数据项的含义。本书第7章中,已经对SQL语句执行计划中各个操作的含义做了详尽的阐述,本章中,我们将对S…

爱普特APT32F110x系列时钟介绍

最近要用APT32F110x做一些开发,顺便学习一下。 APT32F110x 是由爱普特推出的基于平头哥(T-Head Microsystems)CPU 内核开发的 32 位高性能低成本单片机。 APT32F1104x基于嵌入式 Flash 工艺制造,内部丰富的模拟资源,包…

ShardingJDBC核心概念与快速实战

目录 ShardingSphere介绍 ShardingSphere特点 ShardingSphere简述 ShardingSphere产品区分 ShardingJDBC实战 核心概念 实战 ShardingJDBC的分片算法 ShardingSphere目前提供了一共五种分片策略: 分库分表带来的问题 ShardingSphere介绍 ShardingSphere特…

结合SSE实现实时位置展示与轨迹展示

概述 实时位置与实时轨迹的展示是webgis中非常常见的一个功能,本文结合SSE来实现实现此功能。 SSE简介 SSE是Sever-Sent Event的首字母缩写,它是基于HTTP协议的,在服务器和客户端之间打开一个单向通道,服务端响应的不再是一次性…

车牌输入框 封装 (小程序 vue)

车牌输入框 封装 小程序licenseNumber.jslicenseNumber.jsonlicenseNumber.wxmllicenseNumber.wxss样例 vuevnp-input-box.vuevnp-input.vuevnp-keyboard.vue样例 小程序 licenseNumber.js const INPUT_NUM 8;//车牌号输入框个数 const EmptyArray new Array(INPUT_NUM).fi…

6个「会议议程」实例和免费模板

我们都参加过一些团队会议,在这些会议上,大多数与会者对会议的目的一无所知,而发言者则使讨论偏离轨道。 接下来就是一场真正的灾难了。 你会发现你的团队因为“上述会议”而浪费了很多时间,却没有达到任何目的。 好消息! 一个…

【Python】序列类型②-元组

文章目录 1.元组简介2.元组的定义2.1定义只有一个元素的元组 3.元组的下标访问4.元组的常用方法5.使用in判断是否存在元素6.多元赋值操作 1.元组简介 元组和列表一样可以存放多个,不同数据类型的元素 与列表最大的不同就是:列表是可变的,而元组不可变 2.元组的定义 元组的定义:…

TCP/UDP协议

一、协议的概念 什么是协议? 从应用的角度出发,协议可理解为“规则”,是数据传输和数据的解释的规则。 假设,A、B双方欲传输文件。规定: 第一次,传输文件名,接收方接收到文件名,…

Springboot +Flowable,ReceiveTask的简单使用方法

一.简介 ReceiveTask(接受任务),它的图标如下图所示: ReceiveTask 可以算是 Flowable 中最简单的一种任务,当该任务到达的时候,它不做任何逻辑,而是被动地等待用户确认。 ReceiveTask 往往适…

RepVGG: Making VGG-style ConvNets Great Again

文章地址:《RepVGG: Making VGG-style ConvNets Great Again》 代码地址:https://github.com/megvii-model/RepVGG 文章发表于CVPR2021,文章提出一种将训练态和推断态网络结构解耦的方法。文章认为目前复杂的网络结构能够获取更高的精度&am…

学大数据需要java学到什么程度

大数据需求越来越多,只有技术在手不愁找不到工作。 学习大数据需要掌握什么语言基础? 1、Java基础 大数据框架90%以上都是使用Java开发语言,所以如果要学习大数据技术,首先要掌握Java基础语法以及JavaEE方向的相关知识。 2、My…

记一次OJ在线代码编辑器(代码编译+运行,C、C++、Java)

如何在SpringBootVue的项目中实现在线代码编译及执行(支持编译运行C、C、Java),研究了一天,真实能用,下面直接上源码!!! ————————————————————————————…

MySQL 知识:迁移数据目录到其他路径

一、系统环境 操作系统:Centos 7 已安装环境:MySQL 8.0.26 二、开始操作 2.1 关闭SELinux 为了提高 Linux 系统的安全性,在 Linux 上通常会使用 SELinux 或 AppArmor 实现强制访问控制(Mandatory Access Control MAC&#xff…

中间件的概念

中间件(middleware)是基础软件的一大类,属于可复用的软件范畴。中间件在操作系统软件,网络和数据库之上,应用软件之下,总的作用是为处于自己上层的应用软件提供运行于开发的环境,帮助用户灵活、高效的开发和集成复杂的…

阶段二38_面向对象高级_网络编程[UDP单播组播广播代码实现]

知识: InetAddresss:getByName,getHostName,getHostAddress方法UDP通信程序:单播,组播,广播代码实现一.InetAddress 的使用 1.static InetAddress getByName(String host) 确定主机名称的IP地址。主机名称可以是机器名称&#x…