卷积神经网络实现天气图像分类 - P3

news2025/4/15 3:21:55
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P3周:彩色图片识别:天气识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
    • 数据准备
    • 模型设计
    • 模型训练
    • 结果展示
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pyto2.0.0+cu118

步骤

环境设置

首先是包引用

import torch # pytorch主包
import torch.nn as nn # 模型相关的包,创建一个别名少打点字
import torch.optim as optim # 优化器包,创建一个别名
import torch.nn.functional as F # 可以直接调用的函数,一般用来调用里面在的激活函数

from torch.utils.data import DataLoader, random_split # 数据迭代包装器,数据集切分
from torchvision import datasets, transforms # 图像类数据集和图像转换操作函数

import matplotlib.pyplot as plt # 图表库
from torchinfo import summary # 打印模型结构

查询当前环境的GPU是否可用

print(torch.cuda.is_available())

GPU可用情况
创建一个全局的设备对象,用于使各类数据处于相同的设备中

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 当GPU不可用时,使用CPU

# 如果是Mac系统可以多增加一个if条件,启用mps
if torch.backends.mps.is_available():
	device = torch.device('mps')

数据准备

这次的天气图像是由K同学提供的,我提前下载下来放在了当前目录下的data文件夹中
加载文件夹中的图像数据集,要求文件夹按照不同的分类并列存储,一个简要的文件树为

data
	cloudy
	rain
	shine
	sunrise

使用torchvisio.datasets中的方法加载自定义图像数据集,可以免除一些文章中推荐的自己创建Dataset,个人感觉十分方便,而且这种文件的存储结构也兼容keras框架。

首先我们使用原生的PythonAPI来遍历一下文件夹,收集一下分类信息

import pathlib

data_lib = pathlib.Path('data')
class_names = [f.parts[-1] for f in data_lib.glob('*')] # 将data下级文件夹作为分类名
print(class_names)

打印分类信息
在所有的图片中随机选择几个文件打印一下信息。

import numpy as np
from PIL import Image
import random

image_list = list(data_lib.glob('*/*'))
for _ in range(10):
	print(np.array(Image.open(random.choice(image_list))).shape)

打印图像信息
通过打印图像信息,发现图像的大小并不一致,需要在创建数据集时对图像进行缩放到统一的大小。

transform = transforms.Compose([
	transforms.Resize([224, 224]), # 将图像都缩放到224x224
	transforms.ToTensor(), # 将图像转换成pytorch tensor对象
]) # 定义一个全局的transform, 用于对齐训练验证以及测试数据

接下来就可以正式从文件夹中加载数据集了

dataset = datasets.ImageFolder('data', transform=tranform)

现在把整文件夹下的所有文件加载为了一个数据集,需要根据一定的比例划分为训练和验证集,方便模型的评估

train_size = int(len(dataset) *0.8) # 80% 训练集 20% 验证集
eval_size = len(dataset) - train_size

train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

创建完数据集,打印一下数据集中的图像

plt.figure(figsize=(20, 4))
for i in range(20):
	image, label = train_dataset[i]
	plt.subplot(2, 10, i+1)
	plt.imshow(image.permute(1,2,0)) # pytorch的tensor格式为N,C,H,W,在imshow展示需要将格式变成H,W,C格式,使用permute切换一下
	plt.axis('off')
	plt.title(class_names[label])

预览数据集
最后用DataLoader包装一下数据集,方便遍历

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_loader, batch_size=batch_size)

模型设计

使用一个带有BatchNorm的卷积神经网络来处理分类问题

class Network(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		self.conv1 = nn.Conv2d(3, 12, kernel_size=5, strides=1)
		self.conv2 = nn.Conv2d(12, 12, kernel_size=5, strides=1)
		self.conv3 = nn.Conv2d(12, 24, kernel_size=5, strides=1)
		self.conv4 = nn.Conv2d(24, 24, kernel_size=5, strides=1)

		self.maxpool = nn.MaxPool2d(2)

		self.bn1 = nn.BatchNorm2d(12)
		self.bn2 = nn.BatchNorm2d(12)
		self.bn3 = nn.BatchNorm2d(24)
		self.bn4 = nn.BatchNorm2d(24)

		# 224 [-> 220 -> 216 -> 108] [-> 104 -> 100 -> 50]
		self.fc1 = nn.Linear(50*50*24, num_classes)
	
	def forward(self, x):
		x = F.relu(self.bn1(self.conv1(x)))
		x = F.relu(self.bn2(self.conv2(x)))
		x = self.maxpool(x)
		x = F.relu(self.bn3(self.conv3(x)))
		x = F.relu(self.bn4(self.conv4(x)))
		x = self.maxpool(x)
	
		x = x.view(x.size(0), -1)
		
		x = self.fc1(x)

		return x

model = Network(len(class_names)).to(device) # 别忘了把定义的模型拉入共享中
summary(model, input_size=(32, 3, 224, 224))

模型结构

模型训练

首先定义一下每个epoch内训练和评估的逻辑

def train(train_loader, model, loss_fn, optimizer):
	train_size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)
		
		preds = model(x)
		loss = loss_fn(preds, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()
	train_loss /= num_batches
	train_acc /= train_size
	return train_loss, train_acc
	
def eval(eval_loader, model, loss_fn):
	eval_size = len(eval_loader.dataset)
	num_batches = len(eval_loader)
	eval_loss, eval_acc = 0, 0
	for x, y in eval_loader:
		x, y = x.to(device), y.to(device)

		preds = model(x)
		loss = loss_fn(preds, y)

		eval_loss += loss.item()
		eval_acc += (preds.argmax(1) == y).type(torch.float).sum().item()
	eval_loss /= num_batches
	eval_acc /= eval_size

	return eval_loss, eval_acc

然后编写代码进行训练

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 10

train_loss, train_acc = [], []
eval_loss, eval_acc =[], []
for epoch in range(epochs):
	model.train()
	epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
	model.eval()
	model.no_grad():
		epoch_eval_loss, epoch_eval_acc = test(eval_loader, model, loss_fn)

结果展示

训练结果
基于训练和测试数据展示结果

range_epochs = range(len(train_loss))
plt.figure(figsize=(12, 4))
plt.subplot(1,2,1)
plt.plot(range_epochs, train_loss, label='train loss')
plt.plot(range_epochs, eval_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.subplot(1,2,2)
plt.plot(range_epochs, train_acc, label='train accuracy')
plt.plot(range_epochs, eval_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练历史图表

总结与心得体会

通过对训练过程的观察,训练过程中的数据波动很大,并且验证集上的最好正确率只有82%。
目前行业都流行小卷积核,于是我把卷积核调整为了3x3,并且每次卷积后我都进行池化操作,直到通道数为64,由于天气识别时,背景信息也比较重要,高层的卷积操作后我使用平均池化代替低层使用的最大池化,加大了全连接层的Dropout惩罚比重,用来抑制过拟合问题。最后的模型如下:

class Network(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3)
        
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        
        self.maxpool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(2)
        self.dropout = nn.Dropout(0.5)
        
        # 224 -> 222-> 111 -> 109 -> 54 -> 52 -> 50 -> 25
        self.fc1 = nn.Linear(25*25*64, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.avgpool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.avgpool(x)
        
        
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

然后增大训练的epochs为30,学习率降低为1e-4

optimizer = optim.Adam(model.parameters(), lr=1e-4)
epochs = 30

训练结果如下
训练过程
可以看到,验证集上的正确率最高达到了95%以上
训练过程图示

在数据集中随机选取一个图像进行预测展示

image_path = random.choice(image_list)
image_input = transform(Image.open(image_path))
image_input = image_input.unsqueeze(0).to(device)
model.eval()
pred = model(image_input)

plt.figure(figsize=(5, 5))
plt.imshow(image_input.cpu().squeeze(0).permute(1,2,0))
plt.axis('off')
plt.title(class_names[pred.argmax(1)])

结果如下
预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/911635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

问道管理:中国十大科技板块?

跟着科技的开展,各种高科技工业在我国迅猛开展,其中十大板块就是一个比较典型的代表。这十大科技板块涵盖了从电子信息、生命健康到新材料等多个范畴,让我们一起来了解一下这十大板块的开展现状。 一、电子信息 作为国家重点支持开展的工业之…

剑指offer(C++)-JZ64:求1+2+3+...+n(算法-位运算)

作者:翟天保Steven 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 题目描述: 求123...n,要求不能使用乘除法、for、while、if、else、switch、case等关键字及条件判断语句&…

如何使用数学将 NumPy 函数的性能提高 50%

一、说明 2D 傅里叶变换是本世纪最重要的计算机科学算法之一。它已在我们的日常生活中得到应用,从Instagram过滤器到MP3文件的处理。 普通用户最常用的实现,有时甚至是在不知不觉中,是 NumPy 的改编。然而,尽管它很受欢迎&#xf…

CDH集群离线配置python3环境,并安装pyhive、impyla、pyspark

背景: 项目需要对数仓千万级数据进行分析、算法建模。因数据安全,数据无法大批量导出,需在集群内进行分析建模,但CDH集群未安装python3 环境,需在无网情况下离线配置python3环境及一系列第三方库。 采取策略&#xf…

python分析实战(4)--获取某音热榜

1. 分析需求 打开某音热搜,选择需要获取的热榜如图 查找包含热搜内容的接口返回如图 将url地址保存 2. 开发 定义请求头 headers {Cookie: 自己的cookie,Accept: application/json, text/plain, */*,Accept-Encoding: gzip, deflate,Host: www.douyin.com,…

vue3+element下拉多选框组件

<!-- 下拉多选 --> <template><div class"select-checked"><el-select v-model"selected" :class"{ all: optionsAll, hidden: selectedOptions.data.length < 2 }" multipleplaceholder"请选择" :popper-app…

C++信息学奥赛1129:统计数字字符个数

这段代码的功能是计算一个输入字符串中的数字字符个数。 解析注释后的代码如下&#xff1a; #include<bits/stdc.h> using namespace std; int main() {string arr; // 定义字符串变量arr&#xff0c;用来存储输入的字符串getline(cin, arr); // 通过getline函数输入完…

企业文件透明加密软件——「天锐绿盾」数据防泄密管理软件系统

PC访问地址&#xff1a; 首页 一、文档透明加密软件 文档透明加密功能&#xff1a;在不影响单位内部员工对电脑任何正常操作的前提下&#xff0c;文档在复制、新建、修改时被系统强制自动加密。文档只能在单位内部电脑上正常使用&#xff0c;在外部电脑上使用是乱码或无法打…

前端通信(渲染、http、缓存、异步、跨域)自用笔记

SSR/CSR&#xff1a;HTML拼接&#xff1f;网页源码&#xff1f;SEO/交互性 SSR &#xff08;server side render&#xff09;服务端渲染&#xff0c;是指由服务侧&#xff08;server side&#xff09;完成页面的DOM结构拼接&#xff0c;然后发送到浏览器&#xff0c;为其绑定状…

Qt+C++串口调试接收发送数据曲线图

程序示例精选 QtC串口调试接收发送数据曲线图 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<QtC串口调试接收发送数据曲线图>>编写代码&#xff0c;代码整洁&#xff0c;规则&…

为何lazada、亚马逊、速卖通卖家都选择自养账号测评?

无论是做亚马逊还是shopee、Lazada、速卖通、wish、煤炉、拼多多Temu、敦煌网、eBay、Etsy、Newegg、美客多、Allegro、阿里国际、poshmark、沃尔玛、joom、OZON等平台。如果想要销量好&#xff0c;免不了进行补单测评的&#xff0c;因为不管对于哪一个平台的店铺新产品而言&am…

探工业互联网的下一站!腾讯云助力智造升级

引言 数字化浪潮正深刻影响着传统工业形态。作为第四次工业革命的重要基石&#xff0c;工业互联网凭借其独特的价值快速崛起&#xff0c;引领和推动着产业变革方向。面对数字化时代给产业带来的机遇与挑战&#xff0c;如何推动工业互联网的规模化落地&#xff0c;加速数字经济…

开利网络受邀参与御盛马术庄园发展专委会主题会议

近日&#xff0c;开利网络受邀参与深度合作客户御盛马术庄园组织的首届发展专委会主体会议&#xff0c;就马术庄园发展方向进行沟通&#xff0c;数字化也是重要议题之一。目前&#xff0c;御盛马术庄园已经完成数字化系统的初步搭建&#xff0c;将通过线上线下相结合的方式搭建…

编写接口文档示例:从零开始,轻松掌握关键技巧

接口文档的编写是软件开发中至关重要的一环&#xff0c;本文将详细介绍如何编写接口文档示例&#xff0c;为您揭示从基础知识到高级技巧的全过程。通过实用的指导和比喻&#xff0c;让您轻松掌握编写接口文档示例的艺术。 在现代软件开发中&#xff0c;编写接口文档示例是确保项…

Linux 上 离线部署GeoScene Server Py3 运行时环境

默认安装ArcGIS Pro的时候&#xff0c;会自动部署上Python3环境&#xff0c;所以在windows上不需要考虑这个问题&#xff0c;但是linux默认并不部署Py3&#xff0c;因此需要单独部署&#xff0c;具体部署可以参考Linux 上 ArcGIS Server 的 Python 3 运行时—ArcGIS Server | A…

PAT(Advanced Level) Practice(with python)——1067 Sort with Swap(0, i)

Code # 输入有毒&#xff0c;需避坑 # N int(input()) L list(map(int,input().split())) N L[0] L L[1:] res 0 for i in range(1,N):while L[0]!0:# 把所有不在正常位置下的数换到正常t L[0]L[0],L[t] L[t],L[0]res1if L[i]!i:# 换完全后如果对应位置下的数不是目标…

【校招VIP】测试专业课之TCP/IP模型

考点介绍&#xff1a; 大厂测试校招面试里经常会出现TCP/IP模型的考察&#xff0c;TCP/IP协议是网络基础知识&#xff0c;但是在校招面试中很多同学在基础回答中不到位&#xff0c;或者倒在引申问题里&#xff0c;就丢分了。 『测试专业课之TCP/IP模型』相关题目及解析内容可点…

免费开源CRM:有哪些免费开源的CRM系统可供选择?

CRM系统是什么 CRM就是客户关系管理系统&#xff0c;简单来说&#xff0c;就是一个要做到集客户管理&#xff0c;产品进销存&#xff0c;订单跟进&#xff0c;数据分析&#xff0c;售后维护为一体的系统。而开源的CRM系统&#xff0c;最基本的含义是代码是公开的&#xff0c;任…

innovus添加pad的命令

我正在「拾陆楼」和朋友们讨论有趣的话题&#xff0c;你⼀起来吧&#xff1f; 拾陆楼知识星球入口 innovus中添加pad需要使用addInst命令创建physical cell&#xff0c;因为pad没有逻辑功能&#xff0c;不存在与网表中&#xff0c;需要自己创建。 添加pad用addInst -inst $pa…

pr剪辑工具绿色版本

免费提供 提取码: 402s Pr提供了采集、剪辑、调色、美化音频、字幕添加、 输出、DVD刻录等一整套流程&#xff0c; 并和其他Adobe 软件高效集成&#xff0c;使您足以完成在编辑、 制作、工作上遇到的所有挑战&#xff0c;满足您创建高质量作品的要求。 Pr的版本选择 常用…