【Pytorch】计算机视觉项目——卷积神经网络CNN模型识别图像分类

news2025/1/7 9:05:21

目录

  • 一、前言
  • 二、CNN可视化解释器
    • 1. 卷积层工作原理
  • 三、详细步骤说明
    • 1. 数据集准备
    • 2.DataLoader
    • 3. 搭建模型CNN
      • 3.1 设置设备
      • 3.2 搭建CNN模型
      • 3.3 设置loss 和 optimizer
      • 3.4 训练和测试循环
    • 4. 模型评估和结果输出


一、前言

在上一篇笔记《【Pytorch】整体工作流程代码详解(新手入门)》中介绍了Pytorch的整体工作流程,本文继续说明如何使用Pytorch搭建卷积神经网络(CNN模型)来给图像分类。

其他相关文章:
深度学习入门笔记:总结了一些神经网络的基础概念。
TensorFlow专栏:《计算机视觉入门系列》介绍如何用TensorFlow框架实现卷积分类器。


二、CNN可视化解释器

卷积分类器,是通过将导入的图片,一层层筛选、过滤、学习图形的特征,最后实现对输入数据的分类、识别或预测。

下面是github上一个CNN可交互的可视化解释器( 链接在此)。

在这里插入图片描述
从上面的全局图可见,CNN模型由多个卷积层和池化层交替堆叠而成,图片被进行了不同的处理,每一层都被提取出了不同的特征,最终将每个单元的输出汇总,输出分类。

1. 卷积层工作原理

相当于是有一个滤镜格子(叫做卷积核或者滤波器),从左到右、从上至下地扫描整个输入图层,并生成新的图层。
在这里插入图片描述
整个过程中会压缩数据,如下图所示,将一个3*3 的图形,压缩成一个格子。
在这里插入图片描述

参数说明
Input输入数据,中间是一个44的格子
Padding, 外面加的一圈格子,加一个单元。
Kernel Size卷积核大小:这里是3
3 ,左手边的红格子
Stride 步长:卷积核每次走多少格子
在这里插入图片描述


三、详细步骤说明

1. 数据集准备

import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
print(f"Pytorch version:{torch.__version__}\n torchvision version:{torchvision.__version__}")

在这里插入图片描述
数据集介绍:
FashionMNIST是torchvision自带的一个图像数据集,用于机器学习和计算机视觉的训练和测试。它包含了10个不同类别的服装物品的灰度图像,包括T恤、裤子、套衫、裙子、外套、凉鞋、衬衫、运动鞋、包和短靴。每张图片的分辨率是28x28像素。

train_data=datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=None
)

test_data=datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=None
)

#数据集查看
image, label = train_data[0]
# image, label #查看第一条训练数据
image.shape  #查看数据的形状

在这里插入图片描述
图像张量的形状是[1, 28, 28],或者说:[颜色=1,高度=28,宽度=28]

# 查看类别
class_names = train_data.classes
class_names

在这里插入图片描述

#图形可视化
import matplotlib.pyplot as plt
image, label = train_data[0]
print(f"Image shape: {image.shape}")
plt.imshow(image.squeeze()) 
plt.title(label);

在这里插入图片描述

2.DataLoader

from torch.utils.data import DataLoader

# 设置批处理大小超参数
BATCH_SIZE = 32

# 将数据集转换为可迭代的(批处理)
train_dataloader = DataLoader(train_data,
    batch_size=BATCH_SIZE,  # 每个批次有多少样本?
    shuffle=True  # 是否随机打乱?
)

test_dataloader = DataLoader(test_data,
    batch_size=BATCH_SIZE,
    shuffle=False  # 测试数据集不一定需要洗牌
)

#打印结果
print(f"Dataloaders: {train_dataloader, test_dataloader}")
print(f"Length of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}")

在这里插入图片描述

参数介绍:
shuffle:指对数据集进行随机打乱,以便在训练模型时以随机顺序呈现数据。这样做有助于提高模型的泛化能力并减少模型对输入数据顺序的依赖性。相反,对于测试数据集通常被设置为False,因为在评估模型性能时,我们希望保持数据的原始顺序,以便能够正确评估模型在真实数据上的表现。

train_features_batch, train_labels_batch = next(iter(train_dataloader))
train_features_batch.shape, train_labels_batch.shape

在这里插入图片描述

3. 搭建模型CNN

3.1 设置设备

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
device

在GPU上跑。
在这里插入图片描述

3.2 搭建CNN模型

回顾一下CNN的参数设置:

  1. in_channels:输入数据的通道数,对于二维卷积,表示输入图像或特征图的深度或通道数。
  2. out_channels:输出的通道数,即卷积核的数量。每个卷积核生成一个输出通道。
  3. kernel_size:卷积核的大小或滤波器的大小,用整数或元组表示,指定了卷积核的高度和宽度。kernel_size=3意味着卷积核的高度和宽度均为3。
  4. stride:卷积核滑动的步长,决定卷积核在输入数据上滑动的距离。stride=1表示卷积核在输入上每次滑动1个步长。
  5. padding:在输入数据周围填充0的层数。填充有助于保持输入和输出尺寸相同,特别是在卷积层之间传递信息时。这里的padding=1表示在输入数据周围填充一层0,以保持卷积操作后尺寸不变。
# Create a convolutional neural network
class FashionMNISTModelV2(nn.Module):
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3, 
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2) 
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units*7*7,
                      out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        x = self.block_1(x)
        # print(x.shape)
        x = self.block_2(x)
        # print(x.shape)
        x = self.classifier(x)
        # print(x.shape)
        return x

# 加入参数
torch.manual_seed(42)
model_2 = FashionMNISTModelV2(input_shape=1,
    hidden_units=10,
    output_shape=len(class_names)).to(device)
model_2

在这里插入图片描述

3.3 设置loss 和 optimizer

导入accurcay_fn辅助函数文件

import requests
from pathlib import Path

# 从Learn PyTorch存储库中下载辅助函数(如果尚未下载)
if Path("helper_functions.py").is_file():
  print("helper_functions.py已存在,跳过下载")
else:
  print("正在下载helper_functions.py")
  # 注意:你需要使用"raw" GitHub URL才能使其工作
  request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")
  with open("helper_functions.py", "wb") as f:
    f.write(request.content)

创建loss、accuracy和optimizer

from helper_functions import accuracy_fn

# 设置loss和optimizer
loss_fn = nn.CrossEntropyLoss() # this is also called "criterion"/"cost function" in some places
optimizer = torch.optim.SGD(params=model_0.parameters(), lr=0.1)

创建一个计时器

from timeit import default_timer as timer
def print_train_time(start:float,end:float,device:torch.device=None):
  total_time=end-start
  print(f"Train time on {device}: {total_time:.3f} seconds")
  return total_time

3.4 训练和测试循环

def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = device):
    train_loss, train_acc = 0, 0
    model.to(device)
    for batch, (X, y) in enumerate(data_loader):
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        train_loss += loss
        train_acc += accuracy_fn(y_true=y,
                                 y_pred=y_pred.argmax(dim=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")

def test_step(data_loader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):
    test_loss, test_acc = 0, 0
    model.to(device)
    model.eval() # put model in eval mode
    # Turn on inference context manager
    with torch.inference_mode():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            test_pred = model(X)

            test_loss += loss_fn(test_pred, y)
            test_acc += accuracy_fn(y_true=y,
                y_pred=test_pred.argmax(dim=1) # Go from logits -> pred labels
            )
        test_loss /= len(data_loader)
        test_acc /= len(data_loader)
        print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")
torch.manual_seed(42)

from timeit import default_timer as timer
train_time_start_model_2 = timer()

epochs = 3
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n---------")
    train_step(data_loader=train_dataloader,
        model=model_2,
        loss_fn=loss_fn,
        optimizer=optimizer,
        accuracy_fn=accuracy_fn,
        device=device
    )
    test_step(data_loader=test_dataloader,
        model=model_2,
        loss_fn=loss_fn,
        accuracy_fn=accuracy_fn,
        device=device
    )

train_time_end_model_2 = timer()
total_train_time_model_2 = print_train_time(start=train_time_start_model_2,
                                           end=train_time_end_model_2,
                                           device=device)

在这里插入图片描述

4. 模型评估和结果输出

torch.manual_seed(42)
def eval_model(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               accuracy_fn,
               device: torch.device = device):  #注意
    loss, acc = 0, 0
    model.eval()
    with torch.inference_mode():
        for X, y in data_loader:
            #注意设备转移
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss += loss_fn(y_pred, y)
            acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

        
        loss /= len(data_loader)
        acc /= len(data_loader)
    return {"model_name": model.__class__.__name__, 
            "model_loss": loss.item(),
            "model_acc": acc}


model_2_results = eval_model(
    model=model_2,
    data_loader=test_dataloader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn
)
model_2_results

模型输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1175967.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mac电脑大旧型文件清理软件CleanMyMac2024

CleanMyMac的大旧文件模块会帮您定位、检查和移除您几个月没有打开过并且不再需要的大型文件和文件夹,这样可以节省更多的磁盘空间。 CleanMyMac X全新版下载如下: https://wm.makeding.com/iclk/?zoneid49983 大型和旧文件模块可以查找和移除大型文件和文件夹&…

香港账户的美金如何打到国内账户

香港账户的美金可以有多种方式打到国内账户,以下是几种常见的方式: 1.银行电汇:将美元转账到中国大陆的银行账户上并进行换汇操作,这是一种稳妥可靠的方式,但手续费相对较高。 2. 支付宝国际汇款:通过支付…

任正非说:我们要在整体上形成海军陆战队和主力作战团队相配合的作战方案。

你好!这是华研荟【任正非说】系列的第30篇文章,让我们聆听任正非先生的真知灼见,学习华为的管理思想和管理理念。 一、我们的业务量在增长,因此带来表面上人的效益是增长的。但是我们要看到,我们现在利润不是来自于管理…

C++ Qt 学习(三):无边框窗口设计

1. 无边框窗口 1.1 主窗口实现 MainWidget.h #pragma once#include <QtWidgets/QWidget> #include "CTitleBar.h" #include "CFrameLessWidgetBase.h"// 主窗口 MainWidget 继承自无边框窗口公用类 CFrameLessWidgetBase class MainWidget : publi…

全志R128应用开发案例——SPI驱动ST7789V1.3寸LCD

SPI驱动ST7789V1.3寸LCD R128 平台提供了 SPI DBI 的 SPI TFT 接口&#xff0c;具有如下特点&#xff1a; Supports DBI Type C 3 Line/4 Line Interface ModeSupports 2 Data Lane Interface ModeSupports data source from CPU or DMASupports RGB111/444/565/666/888 vide…

【LeetCode:318. 最大单词长度乘积 | 模拟 位运算】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

CRM软件如何高效培育销售线索?

​ 通过线索培育可以挖掘出更多CRM软件销售管道中的有价值客户提高销售业绩。但机遇与挑战总是共存的&#xff0c;培育线索要从不同的渠道执行大量重复性的操作&#xff0c;人为操控不仅速度慢而且容易出错&#xff0c;那么企业如何高效培育销售线索? 发送个性化邮件 我们知…

YOLO目标检测——汽车头部尾部检测数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;用于训练自动驾驶系统中的车辆感知模块&#xff0c;以实现对周围车辆头部和尾部的准确检测和识别数据集说明&#xff1a;汽车头部尾部检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富标签说明&#xff1a;使用lableimg标注软…

随机森林在生物信息中的应用

今天与大家分享一项强大的机器学习算法随机森林。这个算法不仅在数据科学领域广泛应用&#xff0c;还在生物信息学中发挥了巨大的作用。 让我们一起探索随机森林的原理、优缺点以及它在生物信息领域的实际应用场景&#xff0c;本文将给出R语言进行应用的实际方法&#xff0c;利…

数据采集卡如何选型?

数据采集卡如何选型? 一、 确认采集任务二、 选择合适的传感器三、采样频率、分辨率、总线类型、量程等关键参数选择 一、 确认采集任务 二、 选择合适的传感器 三、采样频率、分辨率、总线类型、量程等关键参数选择 第1步&#xff1a;确认采集任务&#xff0c;电压&#x…

产业园区中工业厂房的能源综合配置——工业园区综合能源数字化系统建设方案

以下内容转自微信公众号&#xff1a;PPP产业大讲堂&#xff0c;《产业园区中工业厂房的能源综合配置》。 园区工业地产中能源综合配置存在的问题 我国园区工业地产建设已历经近40年的发展, 园区在区域经济发展、产业集聚方面发挥了重要的载体和平台作用, 有力推动了我国社会经…

未来商业趋势:无人奶柜的无限潜力

未来商业趋势&#xff1a;无人奶柜的无限潜力 随着自动售货机的普及和公共场所需求的多样化&#xff0c;无人奶柜作为一种新兴的自动售货机&#xff0c;开始出现在学校、医院、办公楼、商场等公共场所&#xff0c;为人们提供便捷、低成本的饮品购买服务。 这种无人奶柜不仅可以…

windows 11渗透测试工具箱

系统简介 本环境旨在提供一个开箱即用的windows渗透测试环境&#xff1b;建议运行环境&#xff1a;【vmware&#xff1a;17.0 】 /【运行内存&#xff1a;8G】 /【固态硬盘&#xff1a;100G】 Windows11 Penetration Suite Toolkit v2.2 (WSL) 【推荐】 下载链接&#xff1a;h…

Leetcode-448 找到数组中消失的数字

原理&#xff1a;每个num[i]对应一个数组下标&#xff0c;对所有num[i]下标对应的数变负以后&#xff0c;没有变负的数没有下表对应&#xff0c;这个下标对应的数就缺失&#xff08;好难想&#xff09;。把数组下标当成一个有序数列用&#xff0c;数组里面的元素正负性对数列标…

AI:61-基于深度学习的草莓病害识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

ClickHouse开发系列

一、 ClickHouse详解、安装教程_clickhouse源码安装 二、ClickHouse 语法详解_clickhouse讲解 三、ClickHouse SQL 操作语句详解 四、ClickHouse 高级教程—官方原版 五、ClickHouse主键索引最佳实践 六、MySQL与ClickHouse集成 七、ClickHouse 集成MongoDB、Re…

如何通过一条数字人三维动画宣传片,打造出数字文旅

越来越多虚拟人&#xff0c;以文化挖掘者的身份通过数字人三维动画宣传片&#xff0c;打通次元壁&#xff0c;助力文化传播形式创造性转化、创新性表达&#xff0c;赋予文化发展新动能。 如南方都市报民间博物馆文化探寻者“岭梅香”&#xff0c;由一艘在南宋时期失事的沉船“南…

基于原子轨道搜索算法的无人机航迹规划-附代码

基于原子轨道搜索算法的无人机航迹规划 文章目录 基于原子轨道搜索算法的无人机航迹规划1.原子轨道搜索搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要&#xff1a;本文主要介绍利用原子轨道搜索算法来优化无…

干货满满,mac屏幕录制实用教程!

在当今科技飞速发展的时代&#xff0c;屏幕录制已经成为了人们日常生活中经常使用的功能&#xff0c;无论是工作还是生活&#xff0c;我们都需要使用到屏幕录制软件来捕捉屏幕上的内容。mac作为苹果公司开发的操作系统&#xff0c;拥有许多内置的屏幕录制工具。本文将详细介绍两…

录屏有声音吗?解答你的疑惑

录屏是我们在工作和生活中经常遇到的需求&#xff0c;有时候我们需要记录下屏幕的操作或展示给别人看。然而&#xff0c;很多人在录屏的时候都会遇到一个问题&#xff1a;录制的视频没有声音。那么&#xff0c;录屏有声音吗&#xff1f;答案是肯定的。在本文中&#xff0c;我们…