卷积神经网络实现猴痘疾病图像分类 - P4

news2025/1/13 10:06:53
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P4周:猴痘病识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 全局对象
    • 数据准备
      • 数据集准备
      • 迭代对象准备
    • 模型设计
    • 模型训练
      • 训练函数
      • 评估函数
      • 模型训练
      • 保存最佳模型
    • 结果展示
      • 训练过程图示
      • 加载最佳模型
      • 随机选择一张图片进行预测
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

包引用

首先是引用依赖的包

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import pathlib, random
from PIL import Image
import numpy
import copy # 用来保存最佳模型
from torchinfo import summary #第三方包,用来打包模型实际的结构

然后是创建一个全局的设备对象,如果有显卡的话,使用显卡,创建全局对象是为了防止在运行过程中数据处于不同的设备中报错。

全局对象

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

数据集准备

数据集在原作者的项目中提供了下载,我下载到了项目目录中的data目录下。
数据集的文件结构如下

data
	Monkeypox
		images...
	Others
		images....

使用pathlib库加载数据集,读取分类名以及初步的展示一下数据集

data_path = 'data'
data_lib = pathlib.Path(data_path)
class_names = [f.parts[-1] for f in data_lib.glob('*')]
print(class_names)

数据集分类信息
随机抽取几个图像进行展示,这里我就不贴图片了,猴痘的图片非常的恶心,想看的话自己敲来看吧

image_list = list(data_lib.glob('*/*'))
plt.figure(figsize=(20,4))
for i in range(20):
	plt.subplot(2, 10, i+1)
	image_path = random.choice(image_list)
	image = Image.open(str(image_path))
	print(np.array(image).shape)
	plt.imshow(image)
	plt.axis('off')
	plt.title(image_path.parts[-2]) # 取图片的上级目录名,其实就是分类名称

图像尺寸

迭代对象准备

通过打印可以发现,图片的尺寸一致为224x224,这样我们就不需要对图片进行Resize操作了。接下来我们创建pytorch的数据集

dataset =  datasets.ImageFolder(data_path, transform=transforms.ToTensor())

train_len = int(len(dataset) * 0.8)
test_len = len(dataset) - train_len

train_dataset, test_dataset = random_split(dataset, [train_len, test_len])

batch_size = 64 # 数据的批次大小
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

使用大小为3x3的卷积核加最大池化来设计模型,经历了4个Conv-BN-ReLU-Pool操作后,再加上两层全连接层。
模型结构图

class Network(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		# 224 -> 222
		self.conv1 = nn.Conv2d(3, 16, 3)
		self.bn1 = nn.BatchNorm2d(16)
		# 222 -> 111
		self.maxpool = nn.MaxPool2d(2)
		# 111 -> 109
		self.conv2 = nn.Conv2d(16, 32, 3)
		self.bn2 = nn.BatchNorm2d(32)
		# 109 -> 54 -> 52
		self.conv3 = nn.Conv2d(32, 64, 3)
		self.bn3 = nn.BatchNorm2d(64)
		# 52 -> 26 -> 24
		self.conv4 = nn.Conv2d(64, 64, 3)
		self.bn4 = nn.BatchNorm2d(64)
		# 24 -> 12
		self.fc1 = nn.Linear(64*12*12, 128)
		self.fc2 = nn.Linear(128, num_classes)
		self.dropout = nn.Dropout(0.4)
	def forward(self, x):
		x = F.relu(self.bn1(self.conv1(x)))
		x = self.maxpool(x)
		x = F.relu(self.bn2(self.conv2(x)))
		x = self.maxpool(x)
		x = F.relu(self.bn3(self.conv3(x)))
		x = self.maxpool(x)
		x = F.relu(self.bn4(self.conv4(x)))
		x = self.maxpool(x)
		x = x.view(x.size(0), -1)
		x = self.dropout(x)
		x = F.relu(self.fc1(x))
		x = self.dropout(x)
		x = F.softmax(self.fc2(x), dim =1)
		return x
model = Network(len(class_names)).to(device)
summary(model, (batch_size, 3, 224, 224))

模型结构

模型训练

训练过程中,每个迭代都会跑一次训练和一次评估,因此将训练和评估的过程先写两个函数封装一下

训练函数

def train(model, train_loader, loss_fn, optimizer):
	data_len = len(train_loader.dataset)
	batch_len = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)

		pred = model(x)
		loss = loss_fn(pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= batch_len
	train_acc /= data_len

	return train_loss, train_acc

评估函数

def test(model, test_loader, loss_fn):
	data_len = len(test_loader.dataset)
	batch_len = len(test_loader)
	
	test_loss, test_acc = 0, 0
	for x, y in test_loader:
		x, y = x.to(device), y.to(device)
		
		pred = model(x)
		loss = loss_fn(pred, y)
		test_loss += loss.item()
		test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= batch_len
	test_acc /= data_len

	return test_loss, test_acc

模型训练

在训练之前,首先需要定义好损失函数,学习率优化器等对象,这里我还增加了一个动态修改学习率的scheduler,防止模型在最优解附近不收敛。

epochs = 50
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.92**(epoch//4)) # 定义一个梯度衰减

开始训练

train_loss, train_acc = [], [] # 记录训练的历史过程数据
test_loss, test_acc = [], [] # 记录验证的历史过程数据
best_acc = 0
for epoch in range(epochs):
	model.train()
	epoch_train_loss, epoch_train_acc = train(model, train_loader, loss_fn, optimizer)
	model.eval()
	with torch.no_grad():
		epoch_test_loss, epoch_test_acc = test(model, test_loader, loss_fn)
	
	train_loss.append(epoch_train_loss)
	train_acc.append(epoch_train_acc)
	test_loss.append(epoch_test_loss)
	test_acc.append(epoch_test_acc)

	scheduler.step() # 每个epoch结束后调用一次梯度衰减,就会触发一次Lambda

	if epoch_test_acc > best_acc:
		best_acc = epoch_test_acc
		best_model = copy.deepcopy(model) # 将正确率最高的模型深拷贝一份,保存下来

	print(f"Epoch {epoch+1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc*100:.1f}, ValLoss: {epoch_test_loss:.3f}, ValAcc: {epoch_test_acc*100:.1f}")
print(f"Done. the best accuracy is {best_acc}")

训练过程
经过50个epoch的迭代,模型的正确率最好可以达到91.3%

保存最佳模型

torch.save(best_model.state_dict(), 'best_model.pth')

结果展示

训练过程图示

使用matplotlib的折线图,打印训练过程中,训练集和验证集上的损失和正确率

data_range = range(epochs)
plt.figure(figsize=(20,5))
plt.subplot(1,2, 1)
plt.plot(data_range, train_loss, label='train loss')
plt.plot(data_range, test_loss, label='validation loss')
plt.title('loss')
plt.legend(loc='upper right')

plt.subplot(1,2,2)
plt.plot(data_range, train_acc, label='train accuracy')
plt.plot(data_range, test_acc, label='validation accuracy')
plt.title('accuracy')
plt.legend(loc='lower right')

训练过程图示

加载最佳模型

model.load_state_dict(torch.load('best_model.pth', map_location=device))

随机选择一张图片进行预测

image_path = random.choice(image_list)
raw_image = Image.open(str(image_path))
image = transforms.ToTensor()(raw_image)
image = image.unsqueeze(0).to(device)
pred = model(image)

plt.figure(figsize=(5,5))
plt.axis('off')
plt.imshow(raw_image)
plt.title(class_names[pred.argmax(1).cpu()])

结果就不打印了,无法直视。


总结与心得体会

  1. 因为在模型的调试阶段要跑很多次,所以已经要注意每次重新跑的时候重置一下模型的权重(我一般直接重新定义模型)。不然可能无法对比出改动对当前任务是否有效,一定要注意。
  2. 刚开始模型训练了20个epoch,通过折线图可以看出验证集上的正确率还处于上升状态,于是将训练的epoch修改为50,可以看出,训练的末期验证集上的正确率已经不再增长,甚至有所下降,说明模型已经收敛,此时的正确率如果还不符合要求,就需要对模型结构进行改进了。
  3. 当训练集上的损失和正确率与验证集上的损失和正确率相差很大时,说明模型发生了过拟合现象,增大全连接层的Dropout比例,可以一定程度上抑制过拟合现象。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/927231.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python数据分析实战-对列表里面的元素绘制词云图(附源码和实现效果)

实现功能 词云也叫文字云,是一种可视化的结果呈现,原理就是统计文本中高频出现的词,将结果生成一张图片,直观的获取数据的重点信息 实现代码 from wordcloud import WordCloud import matplotlib.pyplot as plt# 假设你的字符串…

浅谈大数据智能审计如何助力审计工作

随着互联网大数据的持续发展,大数据审计近年来面对着相等的机遇和挑战。那么,如果利用大数据等相关技术对审计工作作出突出贡献,单位和企业又该从何入手做好大数据审计工作应用,这些都成为每位审计人员将要面临的重要问题。 1. 政…

电机控制软件框架

应用层包括main 主函数模块,ISR 中断处理函数模块、时基Systick 模块和BLDC 应用接口模块;算法层包括BLDC Algorithm 模块和PID control 模块;驱动层(Driver layer):包括GD32Fxx_Standard_peripheral libra…

基于SSM的OA办公系统Java企业人事信息管理jsp源代码MySQL

本项目为前几天收费帮学妹做的一个项目,Java EE JSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 基于SSM的OA办公系统 系统有1权限:管理员…

课程项目设计--spring security--认证管理功能--宿舍管理系统--springboot后端

写在前面: 还要实习,每次时间好少呀,进度会比较慢一点 本文主要实现是用户管理相关功能。 前文项目建立 文章目录 验证码功能验证码配置验证码生成工具类添加依赖功能测试编写controller接口启动项目 security配置拦截器配置验证码拦截器 …

台积电美国厂施工现场混乱,真令人头痛 | 百能云芯

近日,英伟达公司的财报表现异常亮眼,摩根士丹利不仅点名了台积电成为最大的受益者,还预测每售出一颗H100英伟达芯片,台积电就能获得900美元的利润。然而,美国媒体却曝出了一则不利的消息,称美国亚利桑那州的…

Hbase文档--架构体系

阿丹: 基础概念了解之后了解目标知识的架构体系,就能事半功倍。 架构体系 关键组件介绍: HBase – Hadoop Database,是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统,利用HBase技术可在廉价PC Server上搭建起…

vue3 vite使用 monaco-editor 报错

报错:Unexpected usage at EditorSimpleWorker.loadForeignModule 修改配置: "monaco-editor-webpack-plugin": "^4.2.0",删除不用 版本: "monaco-editor": "^0.28.1", 修改如下: opti…

Python“牵手”虾皮(Shopee)商品详情API接口运用场景及功能介绍,虾皮API接口申请指南

虾皮(Shopee)电商API接口是针对虾皮提供的电商服务平台,为开发人员提供了简单、可靠的技术来与虾皮电商平台进行数据交互,实现一系列开发、管理和营销等操作。其中包括商品详情API接口,通过这个API接口商家可以获取商品…

SLS日志解析配置

分隔符模式 INFO|2023-04-10T11:05:30.12808:00|X.X.X.X|ACCESS_ALLOWED|1 模式:分隔符模式 日志样例:贴文档说明中的样例,或者直接在SLS历史日志里找一行 分隔符:竖线 日志抽取内容Key用文档中说明的变量名 是否接受部分字段&am…

CMIP6中的模式比较计划、CMIP6数据下载、单点降尺度、统计方法的区域降尺度、基于WRF模式的动力降尺度及气候变化、生态、水文等典型案例

CMIP6数据被广泛应用于全球和地区的气候变化研究、极端天气和气候事件研究、气候变化影响和风险评估、气候变化的不确定性研究、气候反馈和敏感性研究以及气候政策和决策支持等多个领域。这些数据为我们理解和预测气候变化,评估气候变化的影响和风险,以及…

Docker容器与虚拟化技术:Harbor私有仓库部署与迁移

目录 一、理论 1.本地私有仓库 2.Harbor 二、实验 1.Docker搭建本地私有仓库 2.docker-compose部署及配置 3.harbor部署及配置 4.登录创建项目 5.在其他客户端上传镜像 6. harbor维护 7.移除 Harbor 服务容器同时保留镜像数据/数据库,并进行迁移 三、问题…

把Android手机变成电脑摄像头

一、使用 DroidCam 使用 DroidCam,你可以将手机作为电脑摄像头和麦克风。一则省钱,二则可以在紧急情况下使用,比如要在电脑端参加一个紧急会议,但电脑却没有摄像头和麦克风。 DroidCam 的安卓端分为免费的 DroidCam 版和收费的 …

服务器数据恢复-服务器RAID6硬盘故障离线的数据恢复案例

服务器数据恢复环境: 服务器中有一组由6块磁盘组建的RAID6磁盘阵列。服务器作为WEB服务器使用,上面运行了MYSQL数据库以及存放了网站代码和其他数据文件。 服务器故障: 在服务器运行过程中该raid6阵列中有两块磁盘先后离线,但是管…

火山引擎DataLeap基于Apache Atlas自研异步消息处理框架

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 字节数据中台DataLeap的Data Catalog系统通过接收MQ中的近实时消息来同步部分元数据。Apache Atlas对于实时消息的消费处理不满足性能要求,内部使用Flin…

Mysql索引、事务与存储引擎 (索引)

一、索引 1、索引的概念: 索引就是一种帮助系统能够更加快速的查找信息的数据结构。 2.索引的作用: ①数据库利用各种快速定位技术,能够大大加快查询速度,这是创建索引的最主要的原因。 ②当表很大或查询涉及到多个表时&#xff0…

Linux安装软件每次靠百度,这次花了些时间,终于算是搞明白了

Linux下安装命令虽然经常使用,但也仅仅是会使用,每次再用时依然的百度 。于是就花了些时间整体的梳理了一番,以便于更好的理解。 1.安装流程介绍 在Linux下安装软件,其实也是遵循着和Windows一样的安装流程。 首先,…

巨人互动|Facebook海外户Facebook游戏全球发布实用策略

Facebook是全球最大的社交媒体平台之一,拥有庞大的用户基数和广阔的市场。对于游戏开发商而言,利用Facebook进行全球发布是一项重要的策略。下面小编将介绍一些实用的策略帮助开发商在Facebook上进行游戏全球发布。 巨人互动|Facebook海外户&Faceboo…

淘宝API技术解析,实现按图搜索淘宝商品

淘宝提供了开放平台接口(API)来实现按图搜索淘宝商品的功能。您可以通过以下步骤来实现: 1. 获取开放平台的访问权限:首先,您需要在淘宝开放平台创建一个应用,获取访问淘宝API的权限。具体的申请步骤和要求…

1.6 服务器处理客户端请求

客户端进程向服务器进程发送一段文本(MySQL语句),服务器进程处理后再向客户端进程发送一段文本(处理结果)。 从图中我们可以看出,服务器程序处理来自客户端的查询请求大致需要经过三个部分,分别…