【Week-P4】CNN猴痘病识别

news2024/11/27 0:32:19

文章目录

  • 一、环境配置
  • 二、准备数据
  • 三、搭建网络结构
  • 四、开始训练
  • 五、查看训练结果
  • 六、总结
    • 2.3 ⭐`torch.utils.data.DataLoader()`参数详解
    • 6.1 `print()`常用的三种输出格式
    • 6.2 修改网络结构,观察训练结果
      • 6.2.1 增加pool2、conv6、bn6,test_accuracy=82.5%
      • 6.2.2 去掉pool2,保留conv6、bn6,增加conv7、bn7,test_accuracy=84.4%
      • 6.2.3 继续增加conv8、bn8、conv9、bn9,test_accuracy=87.2%
      • 6.2.4 继续增加conv10、bn10、conv11、bn11,test_accuracy=82.1%

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
    在这里插入图片描述
  • 本周的代码相对于上周增加指定图片预测与保存并加载模型这个两个模块,在学习这个两知识点后,时间有余的同学请自由探索更佳的模型结构以提升模型是识别准确率,模型的搭建是深度学习程度的重点。

一、环境配置

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets

import os,PIL,pathlib

import sys
from datetime import datetime
print("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("Pytorch version: " + torch.__version__)
print("Python version: " + sys.version)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

在这里插入图片描述

二、准备数据

2.1 打印classeNames列表,显示每个文件所属的类别名称
2.2 打印归一化后的类别名称,01
2.3 划分数据集,划分为训练集&测试集,torch.utils.data.DataLoader()参数详解
2.4 检查数据集的shape

  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classNames
  • 第四步:打印classNames列表,显示每个文件所属的类别名称。
import os,PIL,random,pathlib
print("------------2.1 打印classeNames列表,显示每个文件所属的类别名称------------")
total_datadir= './4-data/'
data_dir = pathlib.Path(total_datadir)

data_paths = list(total_datadir.glob('*'))
classNames = [str(path).split("\\")[1] for path in data_paths]
print("classNames: ", classNames)

print("------------2.2 打印归一化后的类别名称,0或1------------")
# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
print("total_data: ", total_data)
print("total_data.class_to_idx: ", total_data.class_to_idx)

print("------------2.3 划分数据集,划分为训练集&测试集------------")
train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print( f"train_dataset: {train_dataset}, test_dataset: {test_dataset}")
print( f"train_size: {train_size}, test_size: {test_size}")

print("------------2.4 检查数据集的shape------------")
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

在这里插入图片描述

三、搭建网络结构

print("------------3 搭建简单CNN网络------------")
import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        """
        nn.Conv2d()函数:
        第一个参数(in_channels)是输入的channel数量
        第二个参数(out_channels)是输出的channel数量
        第三个参数(kernel_size)是卷积核大小
        第四个参数(stride)是步长,默认为1
        第五个参数(padding)是填充大小,默认为0
        """
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24*50*50, len(classNames))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))      
        x = F.relu(self.bn2(self.conv2(x)))     
        x = self.pool(x)                        
        x = F.relu(self.bn4(self.conv4(x)))     
        x = F.relu(self.bn5(self.conv5(x)))  
        x = self.pool(x)                        
        x = x.view(-1, 24*50*50)
        x = self.fc1(x)

        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = Network_bn().to(device)
model

在这里插入图片描述

四、开始训练

4.1 设置超参数
4.2 编写训练函数
4.3 编写测试函数
4.4 开始正式训练,epochs==20

print("------------4.1 设置超参数------------")
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

print("------------4.2 编写训练函数------------")
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)   # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss
    
print("------------4.3 编写测试函数------------")
def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss
    
print("------------4.4 开始正式训练,epochs==20------------")
epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

在这里插入图片描述

五、查看训练结果

5.1 Loss与Accuracy图
5.2 指定图片进行预测
5.3 保存并加载模型

print("------------5.1 Loss与Accuracy图------------")
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

print("------------5.2 指定图片进行预测------------")
from PIL import Image 

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    # plt.imshow(test_img)  # 展示预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)

    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
    
# 预测训练集中的某张照片
predict_one_image(image_path='./4-data/Monkeypox/M01_01_00.jpg', 
                  model=model, 
                  transform=train_transforms, 
                  classes=classes)
                  
print("------------5.3 保存并加载模型------------")
# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

在这里插入图片描述

六、总结

2.3 ⭐torch.utils.data.DataLoader()参数详解

torch.utils.data.DataLoaderPyTorch 中用于加载和管理数据的一个实用工具类。它允许你以小批次的方式迭代你的数据集,这对于训练神经网络和其他机器学习任务非常有用。DataLoader 构造函数接受多个参数,下面是一些常用的参数及其解释:

  1. dataset(必需参数):这是你的数据集对象,通常是 torch.utils.data.Dataset 的子类,它包含了你的数据样本。
  2. batch_size(可选参数):指定每个小批次中包含的样本数。默认值为 1
  3. shuffle(可选参数):如果设置为 True,则在每个 epoch 开始时对数据进行洗牌,以随机打乱样本的顺序。这对于训练数据的随机性很重要,以避免模型学习到数据的顺序性。默认值为 False
  4. num_workers(可选参数):用于数据加载的子进程数量。通常,将其设置为大于 0 的值可以加>快数据加载速度,特别是当数据集很大时。默认值为 0,表示在主进程中加载数据
  5. pin_memory(可选参数):如果设置为 True,则数据加载到 GPU 时会将数据存储在 CUDA 的锁页内存中,这可以加速数据传输到 GPU。默认值为 False
  6. drop_last(可选参数):如果设置为 True,则在最后一个小批次可能包含样本数小于 batch_size 时,丢弃该小批次。这在某些情况下很有用,以确保所有小批次具有相同的大小。默认值为 False
  7. timeout(可选参数):如果设置为正整数,它定义了每个子进程在等待数据加载器传递数据时的超时时间(以秒为单位),这可以用于避免子进程卡住的情况。默认值为 0,表示没有超时限制
  8. worker_init_fn(可选参数):一个可选的函数,用于初始化每个子进程的状态。这对于设置每个子进程的随机种子或其他初始化操作很有用。

6.1 print()常用的三种输出格式

    1. 带格式输出,{0}是指输出的第0个元素,同理{1}为第1个元素,{2}为第2个… 可以不按顺序排列
      print( "Hello {0}, I'm {2}, I,m {1} year old".format("world", age, name) )
    1. 使用类型输出,指定输出类型
      print( "I am %s, today is %d year"%(name, year) )
    1. f字符串,{}中为元素,是.format的简化形式
      print( f"Today is {year}")  

6.2 修改网络结构,观察训练结果

6.2.1 增加pool2、conv6、bn6,test_accuracy=82.5%

在这里插入图片描述
训练结果如下:
在这里插入图片描述
在这里插入图片描述
训练结果表明:修改网络结构之后,test_accuracy反而从85.3%降低到82.5%,说明此次修改结构不能提升test_accuracy的值。

6.2.2 去掉pool2,保留conv6、bn6,增加conv7、bn7,test_accuracy=84.4%

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
训练结果表明:与6.2.1的修改方法相比,test_accuracy从82.5%提升到84.4%。

6.2.3 继续增加conv8、bn8、conv9、bn9,test_accuracy=87.2%

在这里插入图片描述
训练情况如下:
在这里插入图片描述
在这里插入图片描述
训练结果表明:与6.2.2的修改方法相比,test_accuracy从84.4%提升到87.2%。

6.2.4 继续增加conv10、bn10、conv11、bn11,test_accuracy=82.1%

在这里插入图片描述
训练情况如下:
在这里插入图片描述
在这里插入图片描述
训练结果表明:与6.2.3的修改方法相比,test_accuracy从87.2%降低到82.1%。

综合上述4次修改,可以得出的结论是:适当增加conv、bn层可以有效提升test_accuracy,最好的效果是第三次修改,test_accuracy的值达到了87.2%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1354365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

postman使用-05新建测试集

文章目录 两种方式新建测试集测试集:允许用户以项目或模块的方式对多个接口进行分类和管理。每一个测试请求都可以被看作是一个独立的测试用例,而collections则可以同时管理多个测试用例的执行。方法一:点击左上角直接创建测试方法二&#xf…

ubuntu 执行apt-get update报错

系统是Ubuntu22.04 执行apt-get update 遇到如下情况 E: 无法下载 https://mirrors.tuna.tsinghua.edu.cn/ubuntu/dists/jammy/main/binary-arm64/Packages 404 Not Found [IP: 101.6.15.130 443] E: 无法下载 https://mirrors.tuna.tsinghua.edu.cn/ubuntu/dists/jammy-upda…

通灵术揭秘:空碗“竖筷子”不倒

通灵术揭秘:空碗“竖筷子”不倒 释名:竖筷子是流传很广的一种民间小术,因其法是在碗中竖起一支或三支筷子,故名。 用处:如果有人莫名其妙的生病了,医药无效,按民间的说法,就是遇鬼了…

Spark二、Spark技术栈之Spark Core

Spark Core spark核心:包括RDD、RDD算子、RDD的持久化/缓存、累加器和广播变量 学习链接:https://mp.weixin.qq.com/s/caCk3mM5iXy0FaXCLkDwYQ 一、 RDD 1.1 为什么要有RDD 在许多迭代式算法(比如机器学习、图算法等)和交互式数据挖掘中,…

基于SSM的校园快递管理系统

目录 前言 开发环境以及工具 项目功能介绍 学生: 管理员: 详细设计 获取源码 前言 本项目是一个基于IDEA和Java语言开发的基于SSM的校园快递管理系统应用。应用包含学生端和管理员端等多个功能模块。 欢迎使用我们的校园快递管理系统!我…

清风数学建模笔记-多分类-fisher线性判别分析

内容:Fisher线性判别分析 一.介绍: 1.给定的训练姐,设法投影到一维的直线上,使得同类样例的投影点尽可能接近和密集,异类投影点尽可能远离。 2.如何同类尽可能接近:方差越小 3.如何异类尽可能远离&#…

如何将Docker中的Tomact彻底删除

目录 前言: 一.删除Tomcat容器 列出所有在运行的容器信息 ​编辑 如果tomcat容器正在运行先停止,可以通过容器id或者容器名称 再次查看容器运行情况,可以看到没有运行中的容器了. 查看所有容器(-a表示查看所有)无…

MySQL取出N列里最大or最小的一个数据

如题,现在有3列,都是数字类型,要取出这3列里最大或最小的的一个数字 -- N列取最小 SELECT LEAST(temperature_a,temperature_b,temperature_c) min FROM infrared_heat-- N列取最大 SELECT GREATEST(temperature_a,temperature_b,temperat…

Basis Pursuit ADMM

c笔记 ref. distr_opt_stat_learning_admm.html Basis pursuit is the equality-constrained minimization problem In ADMM form, basis pursuit can be written as The ADMM algorithm is then The x-update, which involves solving a linearly-constrained minimu…

Vue v-html中内容图片过大自适应处理

之前图片如下&#xff0c;图片已经超出了页面的展示范围 对v-html增加样式处理 <div class"body padding-l scroll " v-html"docData.content"> </div><style scoped>.body >>> img {max-width: 100% ;} </style>…

XYZ世代

Z世代&#xff0c;Gen Zers&#xff0c;Generation Z &#xff0c;一词最早出现于欧美地区&#xff0c;是美国及欧洲的流行用语&#xff0c;泛指在1995-2009年间出生的一代人&#xff0c;千禧后一代。又称网络世代、互联网世代&#xff0c;网生代&#xff0c;二次元世代&#x…

【第一期】操作系统期末大揭秘:知识回顾与重点整理

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;网络奇遇记、数据结构 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 操作系统概述1.1 操作系统定义1.2 操作系统的作用1.3 操作系统的功能1.4 操作…

机器人制作开源方案 | 核酸检测辅助机器人

作者&#xff1a;周文亚、胡冲、王晓强、张娟 单位&#xff1a;北方民族大学 指导老师&#xff1a;马行、穆春阳 1. 场景调研 新型冠状病毒肺炎全球流行已近三年&#xff0c;其变异毒株不断增强的传播力同时其症状不断变轻&#xff0c;其中无症状&#xff08;怎么确认是否被…

EM算法公式详细推导

EM算法是什么&#xff1f; EM算法是一种迭代算法&#xff0c;用于含隐变量概率模型参数的极大似然估计&#xff0c;或极大后验概率估计。EM算法由两步组成&#xff1a;E步&#xff0c;求期望&#xff1b;M步&#xff1a;求极大。EM算法的优点是简单性和普适性。 符号说明&…

QT基础知识

QT基础知识 文章目录 QT基础知识1、QT是什么2、Qt的发展史3、为什么学习QT4、怎么学习QT1、工程的创建(环境的下载与安装请百度&#xff09;2、创建的工程结构说明3、怎么看帮助文档1、类使用的相关介绍2. 查看所用部件&#xff08;类&#xff09;的相应成员函数&#xff08;功…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《基于碳捕集与封存-电转气-电解熔融盐协同的虚拟电厂优化调度》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主的专栏栏目《论文与完整程序》 这个标题涉及到多个关键概念&#xff0c;让我们逐一解读&#xff1a; 碳捕集与封存&#xff08;Carbon Capture and Storage&#xff0c;CCS&#xff09;&a…

【Linux】常用的基本命令指令①

前言&#xff1a;从今天开始&#xff0c;我们逐步的学习Linux中的内容&#xff0c;和一些网络的基本概念&#xff0c;各位一起努力呐&#xff01; &#x1f496; 博主CSDN主页:卫卫卫的个人主页 &#x1f49e; &#x1f449; 专栏分类:数据结构 &#x1f448; &#x1f4af;代码…

如何解决大模型的「幻觉」问题?

如何解决大模型的「幻觉」问题&#xff1f; 如何解决大模型的「幻觉」问题&#xff1f;幻觉产生原因&#xff1f;模型原因数据层面 幻觉怎么评估&#xff1f;Reference-based&#xff08;基于参考信息&#xff09;基于模型的输入、预先定义的目标输出基于模型的输入 Reference-…

基于ssm的资产管理信息系统+vue论文

摘要 当下&#xff0c;正处于信息化的时代&#xff0c;许多行业顺应时代的变化&#xff0c;结合使用计算机技术向数字化、信息化建设迈进。以前企业对于资产信息的管理和控制&#xff0c;采用人工登记的方式保存相关数据&#xff0c;这种以人力为主的管理模式已然落后。本人结…

《Linux C编程实战》笔记:实现自己的myshell

ok&#xff0c;考完试成功复活 这次是自己的shell命令程序的示例 流程图&#xff1a; 关键函数 1.void print_prompt() 函数说明&#xff1a;这个函数打印myshell提示符&#xff0c;即“myshell$$”. 2.void get_input(char *buf) 函数说明&#xff1a;获得一条指令&#…