深度學習筆記14-CIFAR10彩色圖片識別(Pytorch)

news2024/11/16 5:28:36
  • 🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客
  • 🍖 原作者:K同学啊 | 接輔導、項目定制

一、我的環境

  • 電腦系統:Windows 10

  • 顯卡:NVIDIA GeForce GTX 1060 6GB

  • 語言環境:Python 3.7.0

  • 開發工具:Sublime Text,Command Line(CMD)

  • 深度學習環境:1.12.1+cu113


二、準備套件

# PyTorch 的核心模組,包含了張量操作、自動微分、神經網絡構建、優化器等
import torch

# PyTorch 的神經網絡模組,包含了各種神經網絡層和相關操作的類別和函數
import torch.nn as nn

# Matplotlib 的繪圖模組,用於創建各種圖表和視覺化數據
import matplotlib.pyplot as plt

# PyTorch 的計算機視覺工具包,包含了常用的數據集、模型和圖像轉換操作
import torchvision

# 一個用於數值計算的 Python 庫,提供了高效的數組和矩陣操作功能
import numpy as np

# PyTorch 的函數式神經網絡操作模組,包含了神經網絡中常用的操作,例如激活函數、損失函數等
import torch.nn.functional as F

# 提供 PyTorch 模型的詳細摘要信息,包括層數、參數數量和輸出形狀等,類似於 Keras 的 model.summary()
from torchinfo import summary

# 隱藏警告
import warnings

三、環境準備

# 忽略警告訊息
warnings.filterwarnings("ignore")  

# 輸出 PyTorch 的版本號
print(torch.__version__)

# 檢查是否有可用的CUDA設備,否則使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 打印出當前使用的設備
print(device)


四、載入數據

# 載入CIFAR-10訓練數據集
train_ds = torchvision.datasets.CIFAR10(
    'data',  # 數據下載後保存的目錄
    train=True,   # 指定載入訓練數據集
    transform=torchvision.transforms.ToTensor(), # 將圖像轉換為Tensor
    download=True # 如果數據集不存在,則從網絡下載
)

test_ds  = torchvision.datasets.CIFAR10(
    'data',   # 數據下載後保存的目錄
    train=False,   # 指定載入測試數據集
    transform=torchvision.transforms.ToTensor(),   # 將圖像轉換為Tensor
    download=True  # 如果數據集不存在,則從網絡下載
)


五、數據預處理

# 定義每個批次的大小為32
batch_size = 32

# 創建訓練數據的DataLoader
train_dl = torch.utils.data.DataLoader(
    train_ds,   # 訓練數據集
    batch_size=batch_size,   # 每個批次包含32個樣本
    shuffle=True # 在每個epoch開始時打亂數據
)

# 創建測試數據的DataLoader
test_dl  = torch.utils.data.DataLoader(
    test_ds, # 測試數據集
    batch_size=batch_size # 每個批次包含32個樣本
    # 測試數據集不需要shuffle,默認為False
)

# 從訓練數據加載器中取出一個批次的圖像和標籤
imgs, labels = next(iter(train_dl))
# 打印圖像的形狀 (batch_size, channels, height, width)
print(imgs.shape)


六、圖片可視化

# 創建一個大小為 (20, 5) 的圖形
plt.figure(figsize=(20, 5)) 
# 遍歷前20個圖像
for i, imgs in enumerate(imgs[:20]):
    # 將圖像從 (channels, height, width) 轉換為 (height, width, channels) 以便於顯示
    npimg = imgs.numpy().transpose((1, 2, 0))
    # 在2行10列的子圖中繪製圖像
    plt.subplot(2, 10, i+1)
    # 顯示圖像,使用灰度色彩映射
    plt.imshow(npimg, cmap=plt.cm.binary)
    # 隱藏坐標軸
    plt.axis('off')
# 顯示圖形
plt.show()


七、定義模型

# 定義分類數量(CIFAR-10有10個類別)
num_classes = 10 

# 定義模型類
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 定義第一個卷積層,輸入通道為3(CIFAR-10圖像的RGB通道),輸出通道為64,卷積核大小為3x3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)  
        # 定義第一個池化層,使用2x2的最大池化
        self.pool1 = nn.MaxPool2d(kernel_size=2)       
        # 定義第二個卷積層,輸入通道為64,輸出通道為64,卷積核大小為3x3
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)   
        # 定義第二個池化層,使用2x2的最大池化
        self.pool2 = nn.MaxPool2d(kernel_size=2) 
        # 定義第三個卷積層,輸入通道為64,輸出通道為128,卷積核大小為3x3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)    
        # 定義第三個池化層,使用2x2的最大池化
        self.pool3 = nn.MaxPool2d(kernel_size=2) 
                 
        # 定義第一個全連接層,輸入大小為512,輸出大小為256
        self.fc1 = nn.Linear(512, 256)          
        # 定義第二個全連接層,輸出大小為 num_classes(分類的數量)
        self.fc2 = nn.Linear(256, num_classes)
        
    # 定義前向傳播
    def forward(self, x):
        # 通過第一個卷積層、ReLU激活函數和池化層
        x = self.pool1(F.relu(self.conv1(x)))     
        # 通過第二個卷積層、ReLU激活函數和池化層
        x = self.pool2(F.relu(self.conv2(x)))
        # 通過第三個卷積層、ReLU激活函數和池化層
        x = self.pool3(F.relu(self.conv3(x)))
        
        # 將特徵圖展平為一維向量
        x = torch.flatten(x, start_dim=1)
        # 通過第一個全連接層和ReLU激活函數
        x = F.relu(self.fc1(x))
        # 通過第二個全連接層,輸出為分類結果
        x = self.fc2(x)
       
        # 返回輸出
        return x

# 將模型移動到指定設備(GPU或CPU)
model = Model().to(device)

# 打印模型結構摘要
summary(model)


八、定義訓練函數

# 定義損失函數為交叉熵損失
loss_fn = nn.CrossEntropyLoss() 

# 設定學習率為0.01
learn_rate = 1e-2 

# 使用隨機梯度下降(SGD)優化器
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

# 定義訓練函數
def train(dataloader, model, loss_fn, optimizer):
    # 獲取訓練集的大小
    size = len(dataloader.dataset)  
    
    # 獲取批次數量
    num_batches = len(dataloader) 

    # 初始化訓練損失和準確率
    train_loss, train_acc = 0, 0  
    
    # 遍歷訓練數據
    for X, y in dataloader: 
        # 將輸入和標籤移動到指定設備(GPU或CPU)
        X, y = X.to(device), y.to(device)
        
        # 前向傳播:計算模型預測
        pred = model(X)         
        
        # 計算損失
        loss = loss_fn(pred, y)
        
        # 反向傳播前清零梯度
        optimizer.zero_grad()  
        
        # 反向傳播:計算梯度
        loss.backward()        
        
        # 更新模型參數
        optimizer.step()     
        
        # 計算訓練準確率
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        
        # 累加訓練損失
        train_loss += loss.item()
            
    # 計算平均訓練準確率
    train_acc /= size
    
    # 計算平均訓練損失
    train_loss /= num_batches

    return train_acc, train_loss

九、定義測試函數

def test(dataloader, model, loss_fn):
    # 獲取測試集的大小
    size = len(dataloader.dataset) 
    
    # 獲取批次數量
    num_batches = len(dataloader)  
    
    # 初始化測試損失和準確率
    test_loss, test_acc = 0, 0
    
    # 禁用梯度計算(加速推理過程)
    with torch.no_grad():
        # 遍歷測試數據
        for imgs, target in dataloader:
            # 將輸入和標籤移動到指定設備(GPU或CPU)
            imgs, target = imgs.to(device), target.to(device)
            
            # 前向傳播:計算模型預測
            target_pred = model(imgs)
            
            # 計算損失
            loss = loss_fn(target_pred, target)
            
            # 累加測試損失
            test_loss += loss.item()
            
            # 計算測試準確率
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    # 計算平均測試準確率
    test_acc /= size
    
    # 計算平均測試損失
    test_loss /= num_batches

    return test_acc, test_loss

十、模型訓練

epochs     = 10   # 訓練的回合數
train_loss = []   # 存儲每個回合的訓練損失
train_acc  = []   # 存儲每個回合的訓練準確率
test_loss  = []   # 存儲每個回合的測試損失
test_acc   = []   # 存儲每個回合的測試準確率

# 訓練和測試循環
for epoch in range(epochs):
    model.train()  # 將模型設置為訓練模式
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  # 訓練模型並返回訓練準確率和損失
    
    model.eval()  # 將模型設置為評估模式
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  # 測試模型並返回測試準確率和損失
    
    train_acc.append(epoch_train_acc)  # 存儲訓練準確率
    train_loss.append(epoch_train_loss)  # 存儲訓練損失
    test_acc.append(epoch_test_acc)  # 存儲測試準確率
    test_loss.append(epoch_test_loss)  # 存儲測試損失
    
    # 格式化並輸出當前回合的訓練和測試結果
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')


十一、結果可視化

# 設定 Matplotlib 的參數以支援中文和負號的顯示
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False       # 用來正常顯示負號
plt.rcParams['figure.dpi']         = 100         # 設定圖表的解析度

epochs_range = range(epochs)  # 訓練回合的範圍

plt.figure(figsize=(12, 3))  # 設置圖表大小
plt.subplot(1, 2, 1)  # 創建一個2行1列的子圖佈局,這裡畫第一個圖

# 畫出訓練和測試的準確率曲線
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')  # 顯示圖例
plt.title('Training and Validation Accuracy')  # 設定標題

plt.subplot(1, 2, 2)  # 畫第二個圖
# 畫出訓練和測試的損失曲線
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')  # 顯示圖例
plt.title('Training and Validation Loss')  # 設定標題

plt.show()  # 顯示圖表


十二、心得

最近開始使用 PyTorch 框架來訓練模型,以下是 PyTorch 和 TensorFlow 2 的差異說明

  • 動態圖 vs 靜態圖

    • PyTorch 使用動態計算圖:這意味著計算圖是即時定義的,每次迭代可以根據需要改變結構,更靈活,更容易進行調試和編程
    • TensorFlow 2 則引入了即時執行(Eager Execution),類似於 PyTorch 的動態圖模式,使得模型的建立和調試更加直觀和靈活,此外,TensorFlow 2 也可以使用靜態圖來進行更高效的低級優化和部署
  • API 設計

    • PyTorch 的 API 設計更直觀和簡潔,更貼近 Python 編程風格,使得學習曲線較平緩,特別適合研究和實驗
    • TensorFlow 2 則採用了 Keras 作為其主要高級 API,提供了更高層次的抽象和簡化,使得模型的定義和訓練更加容易,特別適合生產環境和大規模部署
  • 模型部署

    • TensorFlow 2 在模型訓練後的部署和生產環境中表現更加優異,支持較多的低級優化和部署工具(如 TensorFlow Serving)
    • PyTorch 雖然近年來在這方面有所改進,但相對而言仍有一定的差距,部署需要更多的自定義和額外的工作
  • 社區和生態系統

    • TensorFlow 擁有更大的社區支持和更成熟的生態系統,有更多的文檔、教程和預訓練模型可用
    • PyTorch 的社區雖然較小,但在學術界和研究領域中得到了廣泛的應用和支持,並且快速增長

這次作業中,我學到如何使用 PyTorch 的 torchvision 模組來加載 CIFAR-10 數據集,這個過程包括對圖片進行標準化和轉換,使其適合訓練模型,載入數據後使用 PyTorch 定義一個簡單的卷積神經網絡(CNN)來處理圖像分類任務,這裡使用了幾個卷積層和池化層,以及全連接層,之後定義損失函數和優化器,然後訓練模型,這裡使用交叉熵損失和隨機梯度下降(SGD)優化器,最後使用測試集來評估模型的表現,計算準確率和其他指標

通過調整不同的超參數和嘗試不同的模型架構,也意識到了如何優化模型以達到更好的性能,這個過程不僅加深了我對深度學習的理解,還增強了我解決實際問題的能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1868552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu 18.04 server源码编译安装freeswitch 1.10.7支持音视频通话、收发短信——筑梦之路

软件版本说明 ubuntu版本18.04:https://releases.ubuntu.com/18.04.6/ubuntu-18.04.6-live-server-amd64.iso freeswitch 版本1.10.7:https://files.freeswitch.org/freeswitch-releases/freeswitch-1.10.7.-release.tar.gz spandsp包:https:…

家人们谁懂啊?手机信息删除找不回,原来3个技巧就能恢复

你是否也曾经遇到过这样的情况:手机里的重要信息不小心删除了,翻遍所有地方都找不到,心情烦躁到了极点。其实,信息删除后找回它们并不像你想象的那么复杂。不要担心,因为今天我将与你分享3个技巧,帮助你轻松…

1 哈希应用

O(1) 的哈希 Python中的哈希表主要通过内置的字典(dict)类型实现。对于字典的操作,包括插入(insert)、删除(delete)和查找(lookup)的时间复杂度,在理想情况下…

上位机图像处理和嵌入式模块部署(mcu之iap升级)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 mcu种类很多,如果是开发的时候需要对固件升级,整体还是比较容易的。不管是dap,还是st-link v2、j-link&#xf…

idea中使用springboot进行开发时遇到的工程结构问题汇总

idea中的工程结构和eclipse中不同,但是配置的内容都是一样的。 IDEA中也就是这个页面,快捷键ctrlaltshifts 如果在eclipse中,经常会遇到jre和jdk不正确的情况,但IDEA中这个问题很少,但是IDEA中会经常由于未正常配置根…

ESXI存储设备已经分区,无法创建数据存储。

问题:ESXi 存储设备已经分区完成,并且有 VMFS 文件系统,无法创建数据存储,选项是灰色。 解决办法:通过命令行工具 在现有的 VMFS 分区上创建数据存储。 在现有的 VMFS 分区上创建数据存储 1.ESXI开SSH2.windows自带CMD登入ESXI。&…

使用宝塔安装ModstartCMS (非一键安装)

操作系统 Linux Windows 推荐 Linux 操作系统,性能比较好 软件环境 稳定版 PHP 5.6 PHP 7.0 MySQL >5.0 PHP Extension:Fileinfo Apache/Nginx Laravel 9.0 版本 PHP 8.1 MySQL >5.0 PHP Extension:Fileinfo Apache/Ngin…

Windows 注册表是什么?如何备份注册表?

Windows注册表(Windows Registry)是微软Windows操作系统中的一个重要组件,用于存储系统和应用程序的配置信息和选项。下面就给大家详细讲解一下什么是注册表。 注册表的概念 Windows 注册表是一个集中管理的数据库,存储了系统、…

安霸CVFlow推理开发笔记

一、安霸环境搭建: 1.远程172.20.62.13 2. 打开Virtualbox,所在目录:E:\Program Files\Oracle\VirtualBox 3. 配置好ubuntu18.04环境,Ubuntu密码:amba 4. 安装toolchain,解压Ambarella_Toolchain_CNNGe…

成都工业学院2022级数据库原理及应用专周课程学生选课系统(基础篇)

运行环境 操作系统:Windows 11 家庭版 运行软件:Navicat Premium 16 项目内容 需求分析 学生:选课、退课、查看课程信息、查看选课情况等操作 教师:查看选课名单等操作 管理员:课程管理等操作 实体关系模式图 关…

js计算某个时间距离当前时间多少天,少于7天红色展示

效果图 后端返回数据格式 info:{vip_validity:"2027-09-07" }<div>到期时间&#xff1a;{{ info.vip_validity }}, 剩余<span :class"countdownDays(info.vip_validity) < 7 ? surplus : ">{{ !!info.vip_validity ? countdownDays(inf…

MySQL进阶-索引-使用规则-索引失效情况一(索引列运算,字符串不加引号,头部模糊匹配)

文章目录 1、索引列运算1.1、查询表tb_user1.2、查看tb_user的索引1.3、查询 phone177999900151.4、执行计划 phone177999900151.5、查询 substring(phone,10,2) 151.6、执行计划 substring(phone,10,2) 15 2、字符串不加引号2.1、查询 phone177999900152.2、执行计划 phone177…

强化学习详解:理论基础与核心算法解析

本文详细介绍了强化学习的基础知识和基本算法&#xff0c;包括动态规划、蒙特卡洛方法和时序差分学习&#xff0c;解析了其核心概念、算法步骤及实现细节。 关注TechLead&#xff0c;复旦AI博士&#xff0c;分享AI领域全维度知识与研究。拥有10年AI领域研究经验、复旦机器人智能…

【漏洞复现】用友GRP-U8——SQL注入

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 用友GRP-U8是一款企业管理软件&#xff0c;其系统dialog_moreUs…

客户案例|某 SaaS 企业租户敏感数据保护实践

近年来&#xff0c;随着云计算技术的快速发展&#xff0c;软件即服务&#xff08;SaaS&#xff09;在各行业的应用逐渐增多&#xff0c;SaaS 应用给企业数字化发展带来了便捷性、成本效益与可访问性&#xff0c;同时也带来了一系列数据安全风险。作为 SaaS 产品运营服务商&…

活动|华院计算受邀参加2024全球人工智能技术大会(GAITC),探讨法律大模型如何赋能社会治理

6月22至23日&#xff0c;备受瞩目的2024全球人工智能技术大会&#xff08;GAITC&#xff09;在杭州市余杭区未来科技城隆重举行。本届大会以“交叉、融合、相生、共赢”为主题&#xff0c;集“会、展、赛”为一体&#xff0c;聚“产、学、研”于一堂。值得一提的是&#xff0c;…

论文速览 | IEEE Signal Processing Letters, 2024 | 基于时空上下文学习的事件相机立体深度估计

论文速览 | IEEE Signal Processing Letters, 2024 | 基于时空上下文学习的事件相机立体深度估计 1 引言 在计算机视觉领域,立体深度估计一直是一个备受关注的研究热点。传统的基于帧的方法虽然取得了长足的进步,但在处理运动模糊、低照度和平坦区域等挑战性场景时仍面临诸多…

203. 移除链表元素【链表】【C++】

题目描述 题目描述 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&#xff1a;[1,2,3,4,5] 示例 2&#x…

RabbitMQ的Direct交换机

Direct交换机 BindingKey 在Fanout模式中&#xff0c;一条消息&#xff0c;会被所有订阅的队列都消费。但是&#xff0c;在某些场景下&#xff0c;我们希望不同的消息被不同的队列消费。这时就要用到Direct类型的Exchange。 在Direct模型下&#xff1a; 队列与交换机的绑定&a…

爬虫是什么?

目录 1.什么是互联网爬虫&#xff1f; 2.爬虫核心? 3.爬虫的用途? 4.爬虫分类&#xff1f; 5.反爬手段&#xff1f; 1.什么是互联网爬虫&#xff1f; 如果我们把互联网比作一张大的蜘蛛网&#xff0c;那一台计算机上的数据便是蜘蛛网上的一个猎物&#xff0c;而爬虫程序…