深度学习pytorch实战第P3周--实现天气识别

news2025/1/12 1:01:35
>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

引言

1.复习上周

深度学习pytorch实战-第P2周-彩色图片识别icon-default.png?t=N7T8http://t.csdnimg.cn/f5l6F对于上周的学习,数据集是下载的

2.摆正心态

正如K同学所说,学到第三周左右就会有点感觉了,还真是这样,引领我入门,激发了我的兴趣,同时感谢同济子豪兄。

3.本机环境

见上文

4.学习目标

扎扎实实学好习,走好每一步。

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets

import os,PIL,pathlib,random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

2.导入数据

data_dir = './weather_photos/'
print(data_dir)
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames
  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames
  • 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3.数据可视化

import matplotlib.pyplot as plt
from PIL import Image

# 指定图像文件夹路径
image_folder = './weather_photos/cloudy/'

# 获取文件夹中的所有图像文件
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]

# 创建Matplotlib图像
fig, axes = plt.subplots(3, 8, figsize=(16, 6))

# 使用列表推导式加载和显示图像
for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)
    img = Image.open(img_path)
    ax.imshow(img)
    ax.axis('off')

# 显示图像
plt.tight_layout()
plt.show()

 

 图片大小处理

total_datadir = './data/'

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data

划分数据集(4:1)

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

 

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

 

二、构建简单的CNN网络

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类

参数详解见上文。

import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        """
        nn.Conv2d()函数:
        第一个参数(in_channels)是输入的channel数量
        第二个参数(out_channels)是输出的channel数量
        第三个参数(kernel_size)是卷积核大小
        第四个参数(stride)是步长,默认为1
        第五个参数(padding)是填充大小,默认为0
        """
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn5 = nn.BatchNorm2d(24)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(24*50*50, len(classeNames))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))      
        x = F.relu(self.bn2(self.conv2(x)))     
        x = self.pool1(x)                        
        x = F.relu(self.bn4(self.conv4(x)))     
        x = F.relu(self.bn5(self.conv5(x)))  
        x = self.pool2(x)                        
        x = x.view(-1, 24*50*50)
        x = self.fc1(x)

        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = Network_bn().to(device)
model

三、训练模型

1.设置超参数

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)   # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

 

3.编写测试函数

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

4.正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

 也不是伦茨越多越好

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

五、保存模型并预测

# 1.保存模型

# torch.save(model, 'model.pth') # 保存整个模型
torch.save(model.state_dict(), 'model_state_dict.pth') # 仅保存状态字典

# 2. 加载模型 or 新建模型加载状态字典

# model2 = torch.load('model.pth') 
# model2 = model2.to(device) # 理论上在哪里保存模型,加载模型也会优先在哪里,但是指定一下确保不会出错

model2 = Network_bn().to(device) # 重新定义模型
model2.load_state_dict(torch.load('model_state_dict.pth')) # 加载状态字典到模型

# 3.图片预处理
from PIL import Image
import torchvision.transforms as transforms

# 输入图片预处理
def preprocess_image(image_path):
    image = Image.open(image_path)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # 假设使用的是224x224的输入
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # 增加一个批次维度
    return image

# 4.预测函数(指定路径)
def predict(image_path, model):
    model.eval()  # 将模型设置为评估模式
    with torch.no_grad():  # 关闭梯度计算
        image = preprocess_image(image_path)
        image = image.to(device)  # 确保图片在正确的设备上
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)  # 获取最可能的预测类别
        return predicted.item()

# 5.预测并输出结果
image_path = "./weather_photos/shine/shine101.jpg"  # 替换为你的图片路径
prediction = predict(image_path, model)
class_names = ["cloudy", "rain", "shine", "sunrise"]  # Replace with your class labels
predicted_label = class_names[prediction]
print("Predicted class:", predicted_label)

 

Predicted class: shine
# 选取dataloader中的一个图像进行判断
import numpy as np
# 选取图像
imgs,labels = next(iter(train_dl))
image,label = imgs[0],labels[0]

# 选取指定图像并展示
# 调整维度为 [224, 224, 3]
image_to_show = image.numpy().transpose((1, 2, 0))

# 归一化
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image_to_show = std * image_to_show + mean
image_to_show = np.clip(image_to_show, 0, 1)

# 显示图像
plt.imshow(image_to_show)
plt.show()

# 将图像转移到模型所在的设备上(如果使用GPU)
image = image.to(device)

# 预测
with torch.no_grad():
    output = model(image.unsqueeze(0))  # 添加批次维度
    
# 输出预测结果
_, predicted = torch.max(output, 1)
class_names = ["cloudy", "rain", "shine", "sunrise"]  # Replace with your class labels
predicted_label = class_names[predicted]
print(f"Predicted: {predicted.item()}, Actual: {label.item()}")

 如何保存模型可以看这篇icon-default.png?t=N7T8http://t.csdnimg.cn/yCkr7

六,总结

构建数据集中

此次数据集不是直接从网上下载的,而是保存在本地的,那么我们构建数据集的代码就有3步了,第一步就是从我们的本地获取数据集,使用pathlib将路径转化为pathlib.path对象,使用glob方法获取路径下的所有文件路径并保存到datapath,通过spilt分割路径并保存在classname中,打印classname列表。这是在确保路径文件夹正确,且分类正确。接下来可以使用matplotlib.pyplot和PIL来获取图片,并可视化,过程中使用列表推导式加载和显示图像,第二步划分数据集,4:1比例划分并保存为dataset对象,第三步定义batch_size大小,使用dataloder加载器来管理数据,

构建cnn网络中

主要是特征提取网络和分类网络的构建,首先进行一个网络初始化init,定义网络结构,卷积层1,bn层1,(nn.BatchNorm2d(12) 是 PyTorch 中的一个批量归一化层。批量归一化用于加速神经网络的训练过程,并提高模型的泛化能力。在卷积神经网络中),卷积2,bn2,池化1,卷积4,bn层4,卷积5,bn层5,池化层2,全连接层1,然后定义进行前向传播结构,并在每一层(卷积+bn)进行relu函数返回给x,然后再池化,再relu,最后拉平,使用view函数,并传给fc层进行分类。

训练模型中

设置超参数,loss一般用交叉熵,学习率一般0.0001,opt采用sgd或者adam

编写训练函数中

编写训练函数时候,要使用size知道dataloaer加载的dataset大小长度,每batchnum就是dataloder的长度,定义训练集损失和准确率,获取图片和标签从dataloder中,计算误差,进行啊反向传播更新参数,每一步自动更新,记录acc和loss,最后计算train的acc和loss

编写测试函数同理

同理,再取图像之前加个当不进行训练时,停止梯度更新就行

来到了正式训练

定义epoch轮次,训练和特使的acc和loss定义,便于后续存储,轮次循环,model.train和model.eval,然后把每一批次的准确率loss都存在列表里,最后输出

最后可视化

定义plt窗口,可视化列表train_acc等等

保存模型

参上,现在我还有点不懂,慢慢学吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1596205.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pbootcms百度推广链接打不开显示404错误页面

PbootCMS官方在2023年4月21日的版本更新中(对应V3.2.5版本),对URL参数添加了如下判断 if(stripos(URL,?) ! false && stripos(URL,/?tag) false && stripos(URL,/?page) false && stripos(URL,/?ext_) false…

[MySQL]数据库原理8——喵喵期末不挂科

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,大大会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

Vulnhub靶机 DC-2渗透详细过程

VulnHub靶机 DC-2 打靶 目录 VulnHub靶机 DC-2 打靶一、将靶机导入到虚拟机当中二、攻击方式主机发现端口扫描服务探针爆破目录web渗透信息收集扫描探针登录密码爆破SSH远程登录rbash提权 一、将靶机导入到虚拟机当中 靶机地址: https://www.vulnhub.com/entry/dc…

51单片机入门_江协科技_27~28_OB记录的自学笔记_AT24C02数据存储秒表

27. AT24C02(I2C总线) 27.1. 存储器介绍 27.2. 存储器简化模型介绍,存储原理 27.3. AT24C02介绍 •AT24C02是一种可以实现掉电不丢失的存储器,可用于保存单片机运行时想要永久保存的数据信息 •存储介质:E2PROM •通讯接口:I2…

如何在Linux系统部署Joplin笔记并结合内网穿透实现无公网IP远程访问

文章目录 1. 安装Docker2. 自建Joplin服务器3. 搭建Joplin Sever4. 安装cpolar内网穿透5. 创建远程连接的固定公网地址 Joplin 是一个开源的笔记工具,拥有 Windows/macOS/Linux/iOS/Android/Terminal 版本的客户端。多端同步功能是笔记工具最重要的功能,…

极大似然估计、最大后验估计、贝叶斯估计

机器学习笔记 第一章 机器学习简介 第二章 感知机 第三章 支持向量机 第四章 朴素贝叶斯分类器 第五章 Logistic回归 第六章 线性回归和岭回归 第七章 多层感知机与反向传播【Python实例】 第八章 主成分分析【PCA降维】 第九章 隐马尔可夫模型 第十章 奇异值分解 第十一章 熵…

NzN的数据结构--外排序

接上文,本篇向大家简单展示一下外排序的实现。先三连后看才是好习惯!!! 在我们刚接触数据结构的时间里,我们只需要对外排序简单了解一下即可,重点要掌握的还是前面我们介绍的比较排序和非比较排序里的计数排…

基于java的社区生活超市管理系统

开发语言:Java 框架:ssm 技术:JSP JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclip…

爱比对软件:您的文本比对专家

在处理文本时,无论是学术研究、法律审查还是软件开发,精确的文本比对都至关重要。爱比对软件通过先进的技术,为您提供了一系列核心功能,旨在优化您的工作流程和数据管理,确保您的文档内容准确无误。 爱比对软件&#…

白盒测试之语句覆盖与分支(判定)覆盖

白盒测试之语句覆盖与分支(判定)覆盖(蓝桥云学习笔记) 1、语句覆盖 实验介绍 白盒测试的目的是通过检查软件内部的逻辑结构,对软件中的逻辑路径进行覆盖测试。控制流分析是白盒测试中的一种重要测试方法,…

实时传输,弹性优先——物联网通讯打造数据上传新标杆

随着信息技术的飞速发展,物联网技术已经成为连接物理世界和数字世界的桥梁。在物联网领域,数据上传的速度、稳定性和灵活性是评价通讯技术优劣的重要指标。近年来,物联网通讯在实时传输、弹性优先方面取得了显著进展,为数据上传树…

第42篇:随机存取存储器(RAM)模块<一>

Q:本期开始我们分期介绍随机存取存储器(RAM)模块及其设计实现方法。 A:随机存储器RAM,即工作时可以随时从一个指定地址读出数据,也可以随时将数据写入一个指定的存储单元。 DE2-115开发板上的Cyclone IV …

java-通过maven导入本地jar包常见的两种方式

一、准备工作 1.1 写一个小demo,将其打包 1.2 再新建一个项目,并在项目文件夹里新建一个lib文件夹,将上个jar包放进去。 二、方法一(重要) 找到那个小demo 的pom 文件将其中的三个信息拷贝到新项目中去 接着 调用demo…

【计算机系统结构】流水方式(续)

📝本文介绍 本文总结了流水机器的相关处理以及非线性流水线的调度 👋作者简介:一个正在积极探索的本科生 📱联系方式:943641266(QQ) 🚪Github地址:https://github.com/sankexilianhua &#x1f…

2A大电流线性稳压器具备两种输出电压范围

概述 PCD3931 是一款低噪声、低压差线性稳压器 (LDO),可提供 2A 输出电流,最大压降仅为 160mV。该器件提供两种输出电压范围。 PCD3931 的输出电压可通过外部电阻分压器在 0.5V 至 5.5V 范围内进行调节。PCD3931 集低噪声、高 PSRR 和高输出电流能力等特…

国密证书VS国外证书:选哪种更靠谱?一文带你了解优劣势!

在数字化生活的今天,无论是网上购物、在线支付还是远程办公,我们的信息安全和隐私保护都显得尤为重要。而数字证书,就像是一把“信任钥匙”,帮助我们确认信息的真实性和保护数据的安全。但是,面对国密证书和国外品牌证…

Java | Leetcode Java题解之第22题括号生成

题目&#xff1a; 题解&#xff1a; class Solution {static List<String> res new ArrayList<String>(); //记录答案 public List<String> generateParenthesis(int n) {res.clear();dfs(n, 0, 0, "");return res;}public void dfs(int n ,int…

【派兹互连-SailWind】这家公司悄然入局,国产EDA突围又有新看头了!

从光刻机到EDA软件&#xff0c; 国产厂商何以突围&#xff1f; 两年前&#xff0c;美发布禁令直接把对中国大陆半导体产业的限制&#xff0c;从光刻机扩大到集成电路所必需的EDA软件领域&#xff0c;在此之前华为因被美国列入实体清单&#xff0c;被三大海外EDA巨头断供&…

MOM系统:制造企业的“神级助手“!

一、大环境下的智能化改造 嘿&#xff0c;亲爱的制造企业老板们&#xff0c;你们是否曾经为生产计划混乱、物料和设备管理无序、产品质量不稳定等等问题而头疼不已&#xff1f;现在&#xff0c;有一个超级助手可以帮助你们解决这些问题&#xff0c;那就是MOM系统&#xff01;什…

《前端面试题》- JS基础 - 伪数组

第一次听说伪数组这个概念&#xff0c;听到的时候还以为是说CSS的伪类呢&#xff0c;网上一查&#xff0c;这东西原来还是个很常见的家伙。 何为伪数组 伪数组有两个特点&#xff1a; 具有length属性&#xff0c;其他属性&#xff08;索引&#xff09;为非负整数但是却不具备…