pytorch-天气识别

news2024/10/5 19:19:24
  •  🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章地址: 365天深度学习训练营-第P3周:天气识别
  • 🍖 作者:K同学啊

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms,datasets
import matplotlib.pyplot as plt
import os,PIL,pathlib
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

2.导入数据

data_dir = './weather_photos/'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('\\')[1] for path in data_paths]
classNames
['cloudy', 'rain', 'shine', 'sunrise']
train_transforms = transforms.Compose([
    transforms.Resize([224,224]),# resize输入图片
    transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换成tensor
    transforms.Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229,0.224,0.225]) # 从数据集中随机抽样计算得到
])

total_data = datasets.ImageFolder(data_dir,transform=train_transforms)
total_data
Dataset ImageFolder
    Number of datapoints: 1125
    Root location: weather_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

3.数据集划分

train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_size,test_size
(900, 225)
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x246934b8df0>,
 <torch.utils.data.dataset.Subset at 0x246934b82b0>)
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
for X,y in test_dl:
    print('Shape of X [N, C, H, W]:', X.shape)
    print('Shape of y:', y.shape)
    break
Shape of X [N, C, H, W]: torch.Size([32, 3, 224, 224])
Shape of y: torch.Size([32])

二、构建简单的CNN网络

import torch.nn.functional as F

num_classes = 4  # 图片的类别数

class Network_bn(nn.Module):
     def __init__(self):
        super().__init__()
         # 特征提取网络
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0) 
        self.bn1 = nn.BatchNorm2d(12)                
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0) 
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(24) 
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)  
        # 分类网络
        self.fc1 = nn.Linear(24*50*50,num_classes)
     # 前向传播
     def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = x.view(-1,24*50*50)
        x = self.fc1(x)
       
        return x
    
model = Network_bn().to(device)
model
Network_bn(
  (conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=60000, out_features=4, bias=True)
)

 

三、训练模型

1.设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

opt
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.0001
    momentum: 0
    nesterov: False
    weight_decay: 0
)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共900张图片
    num_batches = len(dataloader)   # 批次数目,29(900/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

3.编写测试函数

与测试函数和训练函数大致相同,由于不需要进行梯度下降更新权重,所以不需要传入优化器。

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,8(255/32=8,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

4、正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:61.3%, Train_loss:0.975, Test_acc:60.9%,Test_loss:0.961
...
Epoch:18, Train_acc:94.4%, Train_loss:0.255, Test_acc:87.6%,Test_loss:0.315
Epoch:19, Train_acc:93.8%, Train_loss:0.231, Test_acc:92.4%,Test_loss:0.226
Epoch:20, Train_acc:94.9%, Train_loss:0.187, Test_acc:92.0%,Test_loss:0.315
Done

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/74039.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MAC苹果系统安装数字证书的方法

MAC苹果系统安装数字证书的方法之工具/原料 Mac OS电脑一台 数字证书 先讲讲安装方法,mac系统默认浏览器是Safari,那小D在这里就以Safari浏览器为例子,讲解一下相关的安装方法 如果已有开通了数字证书的用户,在重装了系统或是在没有安装安装证书的电脑上进行付款时,会提…

【ELM回归预测】基于非洲秃鹫算法优化极限学习机预测附matlab代码

✅作者简介&#xff1a;热爱科研的Matlab仿真开发者&#xff0c;修心和技术同步精进&#xff0c;matlab项目合作可私信。 &#x1f34e;个人主页&#xff1a;Matlab科研工作室 &#x1f34a;个人信条&#xff1a;格物致知。 更多Matlab仿真内容点击&#x1f447; 智能优化算法 …

FL Studio水果21版本助力原创音乐人(中文完整版All Plugins)

最近&#xff0c;网上算是“风言风语”吧&#xff0c;关于FL Studio是否出21版的说法各异。首先呢&#xff0c;这里先肯定一点&#xff0c;FL Studio即将出FL Studio 21版本&#xff0c;但是正式版已经出来。希望大家不要被网上一些所谓冒充发布的FL Studio21正式版所骗&#x…

信息系统分析与设计:摊位管理信息系统

摊位管理信息系统的分析与设计 1 市场分析 1.1 地摊经济发展背景 1.2 地摊经济逐渐复苏 1.3 地摊经济的放管服 2 目标市场定位 2.1 普通城市居民 2.2 政府相关管理部门 3 系统主要介绍 3.1 系统创新描述 3.2 主要搭建流程 3.3 主要业务模块 3.4 业务流程图 3.5 组…

Vue学习笔记--第二章(尚硅谷学习视频总结)

第二章 Vue组件化编程第二章 Vue组件化编程2.1. 模块与组件、模块化与组件化2.1.1. 模块2.1.2. 组件2.1.3. 模块化2.1.4. 组件化2.2. 非单文件组件2.2.1. 基本使用2.2.2. 组件注意事项2.2.3. 组件的嵌套2.2.4. VueComponent2.2.5. 一个重要的内置关系2.3. 单文件组件第二章 Vue…

【C#基础学习】第十七章、数组

目录 数组 1.数组的类型 1.1 一维数组和矩形数组 1.1.1实例化一维数组和矩形数组 1.2 访问数组元素 1.3 初始化数组 1.3.1 显式初始化一维数组 1.3.2 显式初始化矩形数组 1.3.3 显式初始化的快捷语法 1.3.4 隐式类型数组 1.4 交错数组 1.4.1 声明交错数组 1.4.2 实例…

bump map(凹凸贴图)的一个简单生成方法

用于渲染物体表面&#xff0c;增加真实感的bump map(凹凸贴图)的一个简单生成方法。 1. 在 Perlin Noise Map Generator - OpenProcessing 生成一个perlin noise map&#xff0c; 点击代码按钮&#xff0c;修改生成图像的分辨率 点击 paly 按钮&#xff0c;设置参数&#xf…

学习笔记-3-SVM-10-SVR

细节内容请关注微信公众号&#xff1a;运筹优化与数据科学 ID: pomelo_tree_opt outline 1. Linear regression 2. Support vector regression 3. SVR vs. SVM 4. Linear SVR 5. Kernel SVR ------------------------------------ 1. Linear regression OR里最常使用的…

【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】

基于之前学习的内容&#xff0c;让我们动手实战一个Kaggle比赛的&#xff1a;房价预测实战案例。Kaggle是一个著名的供机器学习爱好者交流的平台&#xff0c;该房价预测实战网址&#xff1a;https://www.kaggle.com/competitions/house-prices-advanced-regression-techniques …

浅析Linux 内存布局

【推荐文章】 路由选择协议——RIP协议 纯干货&#xff0c;linux内存管理-内存管理架构&#xff08;建议收藏&#xff09; 轻松学会linux下查看内存频率,内核函数,cpu频率 X86体系结构 在X86体系结构下&#xff0c;物理内存地址一般从0x0000_0000开始&#xff0c;而Linux内核主…

微信小程序实战之获取用户信息并保存唯一实例

前言 这是我参加掘金启航计划的第二篇文章&#xff0c;这次总结的是获取用户信息并联合 mobx 状态管理库&#xff0c;保存全局唯一的用户对象。 本篇文章基于 微信云开发 &#xff0c;数据从云数据库中取出&#xff0c;使用微信云数据库API进行获取数据&#xff0c;希望观众老…

Altium Designer飞线不从过孔里面出线如何解决?

出现以上飞线不从过孔出线的原因是其拓扑结构所导致&#xff0c;解决方式就是设置下拓扑结构。 1、执行菜单栏命令“设计-规则”&#xff0c;或者快捷键DR&#xff0c;快速打开“PCB规则及约束编辑器”对话框&#xff0c;如图1所示。 2、在对应的对话框中&#xff0c;选择“Rou…

postgres源码解析41 btree索引文件的创建--2

本文将从btbuild函数作为入口从源码角度进行讲解btree文件的创建流程&#xff0c;执行SQL对应为CREATE TABLE wp_shy(id int primary key, name carchar(20))。知识回顾见&#xff1a;postgres源码解析41 btree索引文件的创建–1 执行流程图梳理 _bt_spools_heapscan 执行流程…

2153年,人类已被AI所奴役。就在这一天,作为一名被俘虏的“搜查部队”士兵,你来到了A0007号城外的反抗军基地中

2153年&#xff0c;地球。   人类&#xff0c;已被AI所奴役。   这个AI的缩写名为——PTA&#xff0c;或称“辟塔”。      辟塔的原型&#xff0c;是一个用于分析网络用户消费倾向并立即给出相关引导的软广告程序。   很快&#xff0c;辟塔便成了广大商家的宠儿&…

【华为上机真题 2022】求解连续数列

&#x1f388; 作者&#xff1a;Linux猿 &#x1f388; 简介&#xff1a;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我&#xff0c;关注我&#xff0c;有问题私聊&#xff01; &…

MatLab SimuLink国产代替

MATLab SimuLink国产代替 米国的限制&#xff0c;把工业软件的国产化推到风口浪尖&#xff0c;作为扎根工业软件开源基础架构20多年的UCanCode, 一直是国外顶尖工业软件的基础架构提供商之一。许多国外软件都在这个基础上构建出来&#xff0c;这里我们也希望探讨一下国产代替Ma…

乐享元游的 UWA Pipeline 最佳实践分享

“躬身入局 践行游戏研发工业化”是UWA在2022年研发上坚持的方向&#xff0c;其中UWA Pipeline更是今年在工业化部署上的一个重要的突破口。在近一年里&#xff0c;越来越多的游戏研发团队在日常项目生产开发中&#xff0c;使用UWA Pipeline搭建了符合自身需求的DevOps研发交付…

fat32文件系统分析

fat32文件系统结构&#xff1a; fat32文件系统比fat16文件系统少了根目录FDT&#xff0c;其实是将根目录归结到数据区中了。 注意数据区第一个扇区所在蔟为2号蔟。 首先在磁盘管理中创建一个fat32磁盘&#xff1a; 大小为16GB。 使用winhex打开磁盘。 可以看到MBR在扇区0处…

AI推理卡/tensorRT c++

#####AI 推理卡&#xff1a;我的需求是x86上Nvidia显卡训练好的模型 用在AI推理卡上进行推理### AI 推理卡 环境配置 安装ubuntu系统、AI推理卡环境 1&#xff0c;安装ubuntu20.04.4 过程忽略&#xff0c;网上教程很多。 2&#xff0c;ubuntu20.04.4设置root登录&#xf…

入门系列 - Git基本操作

本篇文章&#xff0c;是基于我自用Linux系统中的自定义文件夹“test_rep”&#xff0c;当做示例演示 具体Git仓库的目录在&#xff1a;/usr/local/git/test_rep Git基本操作 之前我们已经创建了 Git 版本库了&#xff0c;下一步我们将进行一些 Git 的基本操作。 有关 Git 版本…