Pytorch入门实战 P3-天气识别

news2024/11/13 19:19:27

 

目录

一、前期准备

1、查看设备

2、导入本地数据

3、测试下获取到的天气数据

4、图像预处理

5、划分数据集

6、加载数据集

二、搭建简单的CNN网络(特征提取+分类)

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

5、可视化结果

6、训练结果截图对比

7、模型保存


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

今天这篇文章,主要是用来对天气的图片进行识别的。能够训练预测天气模型,然后对预测出的天气模型,使测试率达到90%以上。

基本的流程同之前两篇的类似,区别在于,这篇文章是获取的是本地的图片数据,不是之前以往的在线下载的数据。

一、前期准备

1、查看设备

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

2、导入本地数据

将提前准备好的文件夹,放在和train.py文件同级目录

# 导入数据
data_dir = './weather_photos'
data_dir = pathlib.Path(data_dir)  # 返回的结果为: weather_photos

data_paths = list(data_dir.glob('*'))  # 得到weather_photos下面的所有文件夹路径。例如:weather_photos/sunrise,weather_photos/rain
classNames = [str(path).split('/')[1] for path in data_paths]  # 将上面获取到的data_paths进行分割。
print(classNames)  # 返回的结果为:['cloudy', 'rain', 'shine', 'sunrise']

3、测试下获取到的天气数据

# 展示获取到的天气数据
# 指定图像文件夹路径
image_folder = './weather_photos/cloudy/'
# 获取文件夹中的所有图像文件(.endswith 检查每个文件名f是否以指定的后缀结尾)
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png', 'jpeg'))]
# 创建Matplotlib图像
fig, axes = plt.subplots(3, 8, figsize=(16, 6))  # 创建一个3行8列的网格,设置每个网格的大小为16x6

# 使用列表式,加载和显示图像。
for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)   # os.path.join() 用于连接路基
    img = Image.open(img_path)  # Image.open() 用于打开一个图像文件并返回一个图像对象。
    ax.imshow(img)   # 在子图上显示图像
    ax.axis('off')  # 关闭坐标
# 显示图像
plt.tight_layout()  # 使用plt.tight_layout() 来自动调整子图布局
plt.show()

4、图像预处理

# 图像预处理
train_transforms = transforms.Compose([
    transforms.Resize([224, 244]),  # 图像裁剪,调整大小为224x224
    transforms.ToTensor(),   # 数据类型转化,将PIL图像或numpy数组转换为Pytorch的张量;像素值缩放,[0,255]→[0.0,1.0]
    transforms.Normalize(   # 标准化处理
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

5、划分数据集

# 总的数据集
total_data = torchvision.datasets.ImageFolder('./weather_photos',
                                              transform=train_transforms)

# 划分数据集
train_size = int(0.8*len(total_data))   # 训练数据集占 80%  (900个)
test_size = len(total_data) - train_size  # 剩下的为测试集   (225个)

# 该方法将总体数据集total_data按照指定的大小比例[train_size, test_size]随机划分训练集和测试集。
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

6、加载数据集

# 加载数据集
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)

二、搭建简单的CNN网络(特征提取+分类)

这次搭建的CNN网络,里面使用了归一化nn.BatchNorm2d()

添加nn.BatchNorm2d()常见的时机和位置:

卷积层之后:在卷积层之后添加nn.BatchNorm2d()是非常常见的做法,有助于对卷积操作后的特征图进行归一化,从而加速训练并提高模型的稳定性

②全连接层之前

③残差连接中

# 搭建简单的CNN网络模型
class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)

        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)

        self.poo1 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1,padding=0)
        self.bn3 = nn.BatchNorm2d(24)

        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.fc1 = nn.Linear(24*50*50,len(classNames))  # nn.Linear(in_features, out_features)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        x = x.view(-1, 24*50*50)
        x = self.fc1(x)
        return x


model = Network_bn().to(device)
print(model)

运行后得到的模型结果为:

三、训练模型

1、设置超参数

# 设置超参数
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
learn_rate = 1e-4  # 学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

2、编写训练函数

# 编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目

    train_loss, train_acc = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step() # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float()).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

3、编写测试函数

# 编写测试函数
def test(dataloader, model, loss_fn):
    size=len(dataloader.dataset)
    num_batches = len(dataloader)

    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

4、正式训练

# 4、正式训练
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    model.eval()

    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,epoch_test_acc*100, epoch_test_loss))
print('Done')

5、可视化结果

# 四、可视化结果
warnings.filterwarnings('ignore')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率

epochs_range = range(epochs)
plt.figure(figsize=(12,3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range,train_acc, label='Training Accuracy')
plt.plot(epochs_range,test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

# 在服务器上运行,需要这个
plt.savefig("/data/xx/resultImg.jpg")  # 里面是服务器上存放,运行结果图片的地址
plt.show()

6、训练结果截图对比

 

7、模型保存

# 模型保存
PATH = './model.pth'
torch.save(model.state_dict(), PATH)

这周的内容暂时就到这里了,最主要的变化,就是搭建CNN网络的时候,在每个卷积层后都添加了归一化处理

新增了一个模型保存,将训练的模型保存下来,以便后续使用该模型,进行预测。

下周就可以使用模型进行预测啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1536928.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4、类加载器

2.4.1 什么是类加载器 类加载器(ClassLoader)是Java虚拟机提供给应用程序去实现获取类和接口字节码数据的技术,类加载器只参与加载过程中的字节码获取并加载到内存这一部分。 类加载器会通过二进制流的方式获取到字节码文件的内容&#xff0c…

Visual Studio配置libtorch(cuda安装一步到位)

Visual Studio配置libtorch visual Studio安装cuDNN安装CUDAToolkit安装libtorch下载Visual Studio配置libtorch(cuda版本配置) visual Studio安装 visual Studio点击安装 具体的安装和配置过程这里就不进行细讲了,可以参考我这篇博客Visual Studio配置OpenCV(保姆…

【嵌入式学习】Qtday03.21

一、思维导图 二、练习 自由发挥登录窗口的应用场景,实现一个登录窗口界面。(不要使用课堂上的图片和代码,自己发挥,有利于后面项目的完成) 要求: 1. 需要使用Ui界面文件进行界面设计 2. ui界面上的组件…

vue.js制作学习计划表案例

通俗易懂,完成“学习计划表”用于对学习计划进行管理,包括对学习计划进行添加、删除、修改等操作。 一. 初始页面效果展示 二.添加学习计划页面效果展示 三.修改学习计划完成状态的页面效果展示 四.删除学习计划 当学习计划处于“已完成”状态时&…

栈——数据结构——day4

栈的定义 栈是限定仅在一段进行插入和删除操作的线性表。 我们把允许插入和删除的一端称为栈顶(top),另一端称为栈底(bottom),不含任何数据元素的栈称为空栈。栈又称为后进先出(Last In First Out)的线性表,简称LIFO结构。 栈的插入操作,叫作进栈&#…

开源项目ChatGPT-Next-Web的容器化部署(三)-- k8s deployment.yaml部署

一、说在前面的话 有了docker镜像,要把一个项目部署到K8S里,主要就是编写deployment.yaml。 你需要考虑的是: 环境变量服务的健康检测持久化启动命令程序使用的数据源程序使用的配置文件 因为本前端项目比较简单,这里只做一个…

重学SpringBoot3-Profiles介绍

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-Profiles介绍 Profiles简介如何在Spring Boot中使用Profiles定义Profiles激活ProfilesIDEA设置active profile使用Profile-specific配置文件 条件化Bean…

Python爬虫案例-爬取主题图片(可以选择自己喜欢的主题)

2024年了,你需要网络资源不能还自己再慢慢找吧? 跟着博主一块学习如何利用爬虫获取资源,从茫茫大海中寻找那个她到再妹子群中找妹子,闭着眼睛都可以找到合适的那种。文章有完整示例代码,拿过来就可以用,欢迎…

就业班 第二阶段 2401--3.18 day1 初识mysql

初识: 1、关系型数据库mysql、mariadb、sqlite 二维关系模型 2、非关系型数据库 redis、memcached sql 四个部分 DDL 数据库定义语言 创建数据库,创建用户,创建表 DML 数据库操作语言 增删改 DQL 数据库查询语言 查 DCL 数据库控制语言 授权 …

Pake一键打包,轻松构建桌面级应用!

Pake:顷刻之间,智能封装——WEB到桌面瞬间联通,让网站应用像搭积木般部署 - 精选真开源,释放新价值。 概览 Pake,作为一款新颖且极具创新性的桌面应用开发框架,凭借其独特的技术路径和高效的实现方式&…

时代教育期刊投稿发表

《时代教育》是由成都传媒集团主管主办,中华人民共和国新闻出版总署批准国内公开出版发行的专业教育类期刊,主要刊登各类高等院校、职业技术学校、中小学教师及研究生、教育科研工作者的教育实践研究成果;教育教学行业的最新动态;…

基于SSM+Jsp+Mysql的KTV点歌系统

基于SSMJspMysql的KTV点歌系统 基于SSMJspMysql的KTV点歌系统的设计与实现 开发语言:Java框架:ssm技术:JSPJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工…

jvm提供的远程调试 简单使用

JVM自带远程调试功能 JVM远程调试,其实是两个虚拟机之间,通过socket通信,达到远程调试的目的; 前提 确保本地和远程的网络是开通的; 本地操作 远程操作 在启动命令参数中 把上面的内容复制进去

第 6 章 ROS-URDF练习(自学二刷笔记)

重要参考: 课程链接:https://www.bilibili.com/video/BV1Ci4y1L7ZZ 讲义链接:Introduction Autolabor-ROS机器人入门课程《ROS理论与实践》零基础教程 6.3.4 URDF练习 需求描述: 创建一个四轮圆柱状机器人模型,机器人参数如下,底盘为圆柱…

NIVision-相机图像采集

应用场景 上位机与工业相机通讯,控制相机抓取图像。 工业相机的通讯接口大多为USB口或网口。 USB口则直接将通讯线缆插入上位机USB端口,打开MAX中设备与接口一栏可以看到电脑给相机分配的资源名称;网口则需要将网线连接相机和上位机&#xf…

【数据库】SQL Server 2008 R2 安装过程

启动安装程序,点击setup,进入【SQLServer安装中心】 点击界面左侧的【安装】,然后点击右侧的【全新SQLServer独立安装或向现有安装添加功能】,进入【SQLServer2008R2安装程序】界面,如下图所示: 进入【安装…

浅谈Postman与Jmeter的区别、用法

前阶段做了一个小调查,发现软件测试行业做功能测试和接口测试的人相对比较多。在测试工作中,有高手,自然也会有小白,但有一点我们无法否认,就是每一个高手都是从小白开始的,所以今天我们就来谈谈一大部分人…

师徒互电,眼冒金星,采集系统变电刺激系统!

原文来自微信公众号:工程师看海,很高兴分享我的原创文章,喜欢和支持我的工程师,一定记得给我点赞、收藏、分享哟。 加微信[chunhou0820]与作者进群沟通交流 电的我眼冒金星,以为自己被三体召唤,整个世界为我…

预测一下,GPT-5 会在什么时候发布,又会有哪些更新?

发布预期:GPT-5预计将于11月发布,可能与ChatGPT发布两周年同期。竞争态势:谷歌的Gemini与GPT-4 turbo已展开竞争。逐步发布:GPT-5可能通过模型训练过程中的中间检查点逐步发布。训练与安全测试:实际训练可能需3个月&am…

【Java前端技术栈】Vue2、Vue Cli、Axio入门

一、基本介绍 1.Vue 是什么? Vue (读音 /vjuː/,类似于 view) 是一个前端框架, 易于构建用户界面 2. Vue 的核心库只关注视图层,不仅易于上手,还便于与第三方库或项目整合 3. 支持和其它类库结合使用 4. 开发复杂的单页应用非常方便 5.…