如何通过卷积神经网络(CNN)有效地提取图像的局部特征,并在CIFAR-10数据集上实现高精度的分类?

news2025/3/6 21:52:48

目录

1. CNN 提取图像局部特征的原理

2. 在 CIFAR - 10 数据集上实现高精度分类的步骤

2.1 数据准备

2.2 构建 CNN 模型

2.3 定义损失函数和优化器

2.4 训练模型

2.5 测试模型

3. 提高分类精度的技巧


卷积神经网络(Convolutional Neural Network, CNN)是专门为处理具有网格结构数据(如图像)而设计的深度学习模型,能够有效地提取图像的局部特征。下面将详细介绍如何通过 CNN 提取图像局部特征,并在 CIFAR - 10 数据集上实现高精度分类,同时给出基于 PyTorch 的示例代码。

1. CNN 提取图像局部特征的原理

  • 卷积层:卷积层是 CNN 的核心组件,它通过使用多个卷积核(滤波器)在图像上滑动进行卷积操作。每个卷积核可以看作是一个小的矩阵,用于检测图像中的特定局部特征,如边缘、纹理等。卷积操作会生成一个特征图,特征图上的每个元素表示卷积核在对应位置检测到的特征强度。
  • 局部连接:CNN 中的神经元只与输入图像的局部区域相连,而不是像全连接网络那样与所有输入神经元相连。这种局部连接方式使得网络能够专注于提取图像的局部特征,减少了参数数量,提高了计算效率。
  • 权值共享:在卷积层中,同一个卷积核在整个图像上共享一组权重。这意味着卷积核在不同位置检测到的特征是相同的,进一步减少了参数数量,同时增强了网络对平移不变性的学习能力。
  • 池化层:池化层通常紧跟在卷积层之后,用于对特征图进行下采样,减少特征图的尺寸,降低计算量,同时增强特征的鲁棒性。常见的池化操作有最大池化和平均池化。

2. 在 CIFAR - 10 数据集上实现高精度分类的步骤

2.1 数据准备

CIFAR - 10 数据集包含 10 个不同类别的 60000 张 32x32 彩色图像,其中训练集 50000 张,测试集 10000 张。可以使用 PyTorch 的torchvision库来加载和预处理数据。

import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据预处理步骤
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2.2 构建 CNN 模型

可以构建一个简单的 CNN 模型,包含卷积层、池化层和全连接层。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = x.view(-1, 128 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
2.3 定义损失函数和优化器

使用交叉熵损失函数和随机梯度下降(SGD)优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
2.4 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

for epoch in range(20):  # 训练20个epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
            running_loss = 0.0

print('Finished Training')
2.5 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

3. 提高分类精度的技巧

  • 数据增强:通过随机裁剪、翻转、旋转等操作增加训练数据的多样性,提高模型的泛化能力。
  • 更深的网络结构:可以使用更复杂的 CNN 架构,如 ResNet、VGG 等,这些网络通过引入残差连接、批量归一化等技术,能够更好地学习图像特征。
  • 学习率调整:在训练过程中动态调整学习率,如使用学习率衰减策略,使模型在训练初期快速收敛,后期更精细地调整参数。
  • 正则化:使用 L1 或 L2 正则化、Dropout 等技术防止模型过拟合。

通过以上步骤和技巧,可以有效地利用 CNN 提取图像的局部特征,并在 CIFAR - 10 数据集上实现高精度的分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2310717.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis的持久化-RDBAOF

文章目录 一、 RDB1. 触发机制2. 流程说明3. RDB 文件的处理4. RDB 的优缺点 二、AOF1. 使用 AOF2. 命令写⼊3. 文件同步4. 重写机制5 启动时数据恢复 一、 RDB RDB 持久化是把当前进程数据生成快照保存到硬盘的过程,触发 RDB 持久化过程分为手动触发和自动触发。 …

Redis 的几个热点知识

前言 Redis 是一款内存级的数据库,凭借其卓越的性能,几乎成为每位开发者的标配工具。 虽然 Redis 包含大量需要掌握的知识,但其中的热点知识并不多。今天,『知行』就和大家分享一些 Redis 中的热点知识。 Redis 数据结构 Redis…

靶场之路-VulnHub-DC-6 nmap提权、kali爆破、shell反连

靶场之路-VulnHub-DC-6 一、信息收集 1、扫描靶机ip 2、指纹扫描 这里扫的我有点懵,这里只有两个端口,感觉是要扫扫目录了 nmap -sS -sV 192.168.122.128 PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 7.4p1 Debian 10deb9u6 (protoc…

机器视觉开发教程——封装Halcon通用模板匹配工具【含免费教程源码】

目录 引言前期准备Step1 设计可序列化的输入输出集合【不支持多线程】Step2 设计程序框架1、抽象层【IProcess】2、父类【HAlgorithm】3、子类【HFindModelTool】 Step3 设计UI结果展示 引言 通过仿照VisionPro软件二次开发Halcon的模板匹配工具,便于在客户端软件中…

【3DMAX室内设计】2D转3D平面图插件2Dto3D使用方法

【一键筑梦】革新性2Dto3D插件,轻松实现2D平面图向3D空间的华丽蜕变。这款专为3DMAX室内设计师设计的神器,集一键式墙体、门、窗自动生成功能于一身,能够将2D图形无缝转化为3D网格对象(3D平面图、鸟瞰图),一…

vscode 查看3d

目录 1. vscode-3d-preview obj查看ok 2. vscode-obj-viewer 没找到这个插件: 3. 3D Viewer for Vscode 查看obj失败 1. vscode-3d-preview obj查看ok 可以查看obj 显示过程:开始是绿屏,过了1到2秒,后来就正常看了。 2. vsc…

自动驾驶---不依赖地图的大模型轨迹预测

1 前言 早期传统自动驾驶方案通常依赖高精地图(HD Map)提供道路结构、车道线、交通规则等信息,可参考博客《自动驾驶---方案从有图迈进无图》,本质上还是存在问题: 数据依赖性高:地图构建成本昂贵&#xf…

perl初试

我手头有一个脚本,用于从blastp序列比对的结果文件中,进行文本处理, 获取序列比对最优的hit记录 #!/usr/bin/perl -w use strict;my ($blast_out) ARGV; my $usage "This script is to get the best hit from blast output file wit…

VS Code C++ 开发环境配置

VS Code 是当前非常流行的开发工具. 本文讲述如何配置 VS Code 作为 C开发环境. 本文将按照如下步骤来介绍如何配置 VS Code 作为 C开发环境. 安装编译器安装插件配置工作区 第一个步骤的具体操作会因为系统不同或者方案不同而有不同的选择. 环境要求 首先需要立即 VS Code…

Web Snapshot 网页截图 模块代码详解

本文将详细解析 Web Snapshot 模块的实现原理和关键代码。这个模块主要用于捕获网页完整截图,特别优化了对动态加载内容的处理。 1. 模块概述 snapshot.py 是一个功能完整的网页截图工具,它使用 Selenium 和 Chrome WebDriver 来模拟真实浏览器行为&am…

Windows 10 下 SIBR Core (i.e. 3DGS SIBR Viewers) 的编译

本文针对在 Windows 10 上从源码编译安装3DGS (3D Gaussian Splatting)的Viewers 即SIBR Core及外部依赖库extlibs(预编译的版本直接在页面https://sibr.gitlabpages.inria.fr/download.html下载) ,参考SIBR 的官方网站…

JavaWeb-HttpServletRequest请求域接口

文章目录 HttpServletRequest请求域接口HttpServletRequest请求域接口简介关于请求域和应用域的区别 请求域接口中的相关方法获取前端请求参数(getParameter系列方法)存储请求域名参数(Attribute系列方法)获取客户端的相关地址信息获取项目的根路径 关于转发和重定向的细致剖析…

防火墙虚拟系统实验

拓扑图 需求一 安全策略要求: 1、只存在一个公网IP地址,公司内网所有部门都需要借用同一个接口访问外网 2、财务部禁止访问Internet,研发部门只有部分员工可以访问Internet,行政部门全部可以访问互联网 3、为三个部门的虚拟系统分…

点云滤波方法:特点、作用及使用场景

点云滤波是点云数据预处理的重要步骤,目的是去除噪声点、离群点等异常数据,平滑点云或提取特定频段特征,为后续的特征提取、配准、曲面重建、可视化等高阶应用打下良好基础。以下是点云中几种常见滤波方法的特点、作用及使用场景:…

Gradle 配置 Lombok 项目并发布到私有 Maven 仓库的完整指南

Gradle 配置 Lombok 项目并发布到私有 Maven 仓库的完整指南 在 Java 项目开发中,使用 Lombok 可以极大地减少样板代码(如 getter/setter 方法、构造器等),提高开发效率。然而,当使用 Gradle 构建工具并将项目发布到私…

ArcGIS Pro 基于基站数据生成基站扇区地图

在当今数字化的时代,地理信息系统(GIS)在各个领域都发挥着至关重要的作用。 ArcGIS Pro作为一款功能强大的GIS软件,为用户提供了丰富的工具和功能,使得数据处理、地图制作和空间分析变得更加高效和便捷。 本文将为您…

【Python · Pytorch】Conda介绍 DGL-cuda安装

本文仅涉及DGL库介绍与cuda配置,不包含神经网络及其训练测试。 起因:博主电脑安装了 CUDA 12.4 版本,但DGL疑似没有版本支持该CUDA版本。随即想到可利用Conda创建CUDA12.1版本的虚拟环境。 1. Conda环境 1.1 Conda环境简介 Conda&#xff1…

leetcode:2965. 找出缺失和重复的数字(python3解法)

难度:简单 给你一个下标从 0 开始的二维整数矩阵 grid,大小为 n * n ,其中的值在 [1, n2] 范围内。除了 a 出现 两次,b 缺失 之外,每个整数都 恰好出现一次 。 任务是找出重复的数字a 和缺失的数字 b 。 返回一个下标从…

Android U 分屏——SystemUI侧处理

WMShell相关的dump命令 手机分屏启动应用后运行命令:adb shell dumpsys activity service SystemUIService WMShell 我们可以找到其中分屏的部分,如下图所示: 分屏的组成 简图 分屏是由上分屏(SideStage)、下分屏(MainStage)以及分割线组…

flink集成tidb cdc

Flink TiDB CDC 详解 1. TiDB CDC 简介 1.1 TiDB CDC 的核心概念 TiDB CDC 是 TiDB 提供的变更数据捕获工具,能够实时捕获 TiDB 集群中的数据变更(如 INSERT、UPDATE、DELETE 操作),并将这些变更以事件流的形式输出。TiDB CDC 的…