【深度学习】四种天气分类 模版函数 从0到1手敲版本

news2024/10/6 12:27:21

引入该引入的库

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import torch.optim as optim
%matplotlib inline
import os
import shutil
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

注意:os.environ[“KMP_DUPLICATE_LIB_OK”]=“TRUE” 必须要引入否则用plt出错

数据集整理

img_dir = r"F:\播放器\1、pytorch全套入门与实战项目\课程资料\参考代码和部分数据集\参考代码\参考代码\29-42节参考代码和数据集\四种天气图片数据集\dataset2"
base_dir = r"./dataset/4weather"

img_list = glob.glob(img_dir+"/*.*")
test_dir = "test"
train_dir = "train"
species = ["cloudy","rain","shine","sunrise"]
for idx,img_path in enumerate(img_list):
    _,img_name = os.path.split(img_path)
    if idx%5==0:

        for specie in species:
            if img_path.find(specie) > -1:
                dst_dir = os.path.join(test_dir,specie)
                os.makedirs(dst_dir,exist_ok=True)
                dst_path = os.path.join(dst_dir,img_name)
    else:
        
        for specie in species:
            if img_path.find(specie) > -1:
                dst_dir = os.path.join(train_dir,specie)
                os.makedirs(dst_dir,exist_ok=True)
                dst_path = os.path.join(dst_dir,img_name)
    shutil.copy(img_path,dst_path)

生成测试和训练的文件夹,
目录结构如下:
在这里插入图片描述
rain 下面就是图片了
在这里插入图片描述

构建ds和dl

from torchvision import transforms
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
train_ds=torchvision.datasets.ImageFolder(train_dir,transform)
test_ds = torchvision.datasets.ImageFolder(train_dir,transform)

在这里插入图片描述
在这里插入图片描述
一张图片效果,这是rain图片 这里需要转换维度,把channel放到最后。同时把数据拉到0-1之间,原本std 和mean 【0.5,0,5】数据在-0.5~0.5之间
在这里插入图片描述
类的映射
在这里插入图片描述

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):
    img = (img.permute(1, 2, 0).numpy() + 1)/2
    plt.subplot(2, 3, i+1)
    plt.title(id_to_class.get(label.item()))
    plt.imshow(img)

这个方法要学会
在这里插入图片描述

定义网络

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.conv2 = nn.Conv2d(16,32,3)
        self.conv3 = nn.Conv2d(32,64,3)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(64*10*10,1024)
        self.fc2 = nn.Linear(1024,4)
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = self.dropout(x)
        # print(x.size()) 这里是可以计算出来的,需要掌握计算方法
        x = x.view(-1,64*10*10)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
model = Net()        
preds = model(imgs)
preds.shape, preds

在这里插入图片描述
定义损失函数和优化函数:

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),lr=0.001)

定义网络

def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    for x, y in trainloader:
        if torch.cuda.is_available():
            x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
        
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total
        
        
    test_correct = 0
    test_total = 0
    test_running_loss = 0 
    
    with torch.no_grad():
        for x, y in testloader:
            if torch.cuda.is_available():
                x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total
    
        
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
             )
        
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

训练:

epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
epoch:  0 loss:  0.043 accuracy: 0.714 test_loss:  0.029 test_accuracy: 0.809
epoch:  1 loss:  0.03 accuracy: 0.807 test_loss:  0.023 test_accuracy: 0.867
epoch:  2 loss:  0.024 accuracy: 0.857 test_loss:  0.018 test_accuracy: 0.888
epoch:  3 loss:  0.021 accuracy: 0.869 test_loss:  0.017 test_accuracy: 0.894
epoch:  4 loss:  0.018 accuracy: 0.886 test_loss:  0.014 test_accuracy: 0.921
epoch:  5 loss:  0.017 accuracy: 0.897 test_loss:  0.022 test_accuracy: 0.869
epoch:  6 loss:  0.013 accuracy: 0.923 test_loss:  0.008 test_accuracy: 0.944
epoch:  7 loss:  0.009 accuracy: 0.947 test_loss:  0.011 test_accuracy: 0.924
epoch:  8 loss:  0.006 accuracy: 0.966 test_loss:  0.004 test_accuracy: 0.988
epoch:  9 loss:  0.004 accuracy: 0.979 test_loss:  0.002 test_accuracy: 0.998
epoch:  10 loss:  0.004 accuracy: 0.979 test_loss:  0.005 test_accuracy: 0.966

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
比较重要的点,
1.分类的数据集布局要记住
2.图片经过conv2 多次后的值要会算 todo
3.图片展示的方法要会

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1540231.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据挖掘】实验5:数据预处理(1)

实验5:数据预处理(1) 一:实验目的与要求 1:熟悉和掌握数据预处理,学习数据清洗、数据集成、数据变换、数据规约、R语言中主要数据预处理函数。 二:实验内容 【缺失值分析】 第一步&#xff1…

Python工具-清理Unity(批量深度)清理U3D项目工程保留关键工程文件

前沿 1. Unity工程越来越多,很久不用的工程里存在了很多无用的大文件夹,极大的影响电脑容量。 2. 我电脑里面U3D工程只有17个,但容量就高达60GB,使用自己编写的工具清理后,减到了30GB多。清理了不是很重要的文件和文件…

2024-03-22 问AI: 介绍一下深度学习中的 sigmoid函数

文心一言 Sigmoid函数是深度学习中常用的一种激活函数,主要用于将神经元的输出值压缩到0和1之间。这种特性使得sigmoid函数在二分类问题中特别有用,因为输出值可以被解释为属于某个类别的概率。 sigmoid函数的数学表达式为: σ(x) 1 / (1…

Notepad++ 如何调整显示字面大小

在 Notepad 上,可以使用 ctrl 加上鼠标的左键来滚动来进行调整。 如何恢复默 可以使用 Ctrl 加数字键盘上的 / 键 来恢复默认设置。 当然也可以通过菜单栏上 view 菜单下的 Zoom 选项。 上面的界面中可以看到我们的在 Notepad 中使用的选项。 Notepad 如何调整显示…

llvm后端

SelectionDAGBuilder是LLVM(Low Level Virtual Machine)编译器中的一个重要组件,它负责将LLVM中间表示(Intermediate Representation,IR)转换为SelectionDAG(选择有向无环图)的形式。…

RabbitMq高可用

消息队列高级 服务异步通信-高级篇1.消息可靠性1.1.生产者消息确认1.2.消息持久化1.3.消费者消息确认1.4.消费失败重试机制1.5.总结 2.死信交换机2.1.初识死信交换机2.2.TTL2.3.延迟队列 3.惰性队列3.1.消息堆积问题3.2.惰性队列 4.MQ集群4.1.集群分类4.2.普通集群4.3.镜像集群…

C#,图论与图算法,计算图(Graph)的岛(Island)数量的算法与源程序

1 孤岛数 给定一个布尔矩阵,求孤岛数。一组相连的1形成一个岛。例如,下面的矩阵包含5个岛: 在讨论问题之前,让我们先了解什么是连接组件。无向图的连通分量是一个子图,其中每两个顶点通过一条路径相互连接,并且不与子图外的其他顶点连接。 所有顶点相互连接的图只有一个…

Spring05 SpringIOC DI

名词解释 今天我们来介绍Spring框架的最重要的part之一 SpringIOC 和 DI 这里的SpringIOC 其实是容器的意思,Spring是一个包含了很多工具方法的IOC容器 什么是IOC呢? IOC其实是Spring的核心思想 Inversion of Control (控制反转) 可能这里你还是不理解这个是啥意思 其实就…

xilinx的高速接口构成原理和连接结构

本文来源: V3学院 尤老师的培训班笔记【高速收发器】xilinx高速收发器学习记录Xilinx-7Series-FPGA高速收发器使用学习—概述与参考时钟GT Transceiver的总体架构梳理 文章目录 一、概述:二、高速收发器结构:2.1 QUAD2.1.1 时钟2.1.2 CHANNEL…

【SysBench】OLTP 基准测试示例

前言 本文采用 MySQL 沙盒实例作为测试目标,使用 sysbench-1.20 对其做 OLTP 基准测试。 有关 MySQL 沙盒的更多信息,请参阅 玩转 MySQL Shell 沙盒实例,【MySQL Shell】6.8 AdminAPI MySQL 沙盒 。 1、部署一个 MySQL 沙盒实例 使用 mysq…

【ESP32S3 Sense接入百度在线语音识别】

视频地址: 1. 前言 使用Seeed XIAO ESP32S3 Sense开发板接入百度智能云实现在线语音识别。自带麦克风模块用做语音输入,通过串口发送字符“1”来控制数据的采集和上传。 步骤概括    (1) 在百度云控制端选择“语音识别”并创建应用获取API Key和Secr…

MapReduce学习问题记录

1、如何跳过对某行数据的处理 第一行数据是字段名不需要处理,我们知道第一行偏移量是0(行记录的时候是从数组首地址开始,到了行标识符进行一次计数,这个计数就是行偏移量,从0开始),我们根据偏移…

银行5G短消息应用架构设计

(一)RCS简介 1.1 RCS的提出与标准制定 RCS(Rich Communication Services & Suite,富媒体通信)是GSMA(Groupe Speciale Mobile Association,全球移动通信系统协会)在2008年提出的一种通讯方式,RCS融合了语音、消息…

【算法每日一练]-图论(保姆级教程篇16 树的重心 树的直径)#树的直径 #会议 #医院设置

目录 树的直径 题目:树的直径 (两种解法) 做法一: 做法二: 树的重心: 题目: 会议 思路: 题目:医院设置 思路: 树的直径 定义:树中距离最…

android.os.TransactionTooLargeException解决方案,Kotlin

android.os.TransactionTooLargeException解决方案,Kotlin 首先,特意制造一个让Android发生TransactionTooLargeException的场景,一个Activity启动另外一个Activity,在Intent的Bundle里面塞入一个大的ArrayList: import android.…

阿里云OSS存储的视频如何加水印

OSS是不能进行视频添加水印的,可以图片添加水印。 您可以在视频点播中进行配置: https://help.aliyun.com/zh/vod/user-guide/video-watermarks?spma2c4g.11186623.0.i2 原来的业务代码都是使用python 对oss的 视频进行上传 的,上传的视频路径已经保存到…

设计数据库之外部模式:数据库的应用

Chapter5:设计数据库之外部模式:数据库的应用 笔记来源:《漫画数据库》—科学出版社 设计数据库的步骤: 概念模式 概念模式(conceptual schema)是指将现实世界模型化的阶段进而,是确定数据库理论结构的阶段。 概念模…

系统架构设计-构建系统应用

1. 系统架构目标与设计原则 在设计系统架构时,我们的目标是确保系统具有以下特点: 可靠性:系统能够持续稳定运行,保证业务可用性。可伸缩性:系统能够根据负载变化自动扩展或收缩,以应对不同的流量需求。容…

【Java高级】利用反射机制获取类的所有信息

文章目录 1.相关准备2.导航图3. 相关的方法----------------------------------------------类------------------------------------------------------1 类的修饰符2 类名 -----------------------------------------------属性--------------------------------------------…

FileZilla 链接服务器提示 20 秒连接超时

FileZilla 有个默认设置是如果 20 秒没有数据的话会自动中断链接。 Command: Pass: **************** Error: Connection timed out after 20 seconds of inactivity Error: Could not connect to server修改配置 这个配置是可以修改的,修改的步骤为: …