【深度学习】—CNN卷积神经网络 从原理到实现

news2024/11/30 9:30:32

卷积神经网络(CNN)从原理到实现

什么是卷积神经网络(CNN)?

卷积神经网络(Convolutional Neural Network, CNN)是一种深度学习模型,主要应用于图像分类、目标检测和自然语言处理等领域。与传统神经网络不同,CNN 通过局部感受野、权值共享和池化操作有效减少参数量,同时保留空间信息。

CNN 结构图

在这里插入图片描述


CNN 的核心概念

CNN 包括三大核心操作:卷积激活函数池化

1. 卷积层(Convolution Layer)

卷积层的目的是提取特征,通过卷积核(Filter)对输入进行特征提取。卷积的数学公式如下:

y [ i , j ] = ∑ m = 0 k − 1 ∑ n = 0 k − 1 x [ i + m , j + n ] ⋅ w [ m , n ] + b y[i, j] = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} x[i+m, j+n] \cdot w[m, n] + b y[i,j]=m=0k1n=0k1x[i+m,j+n]w[m,n]+b

  • (x):输入图像
  • (w):卷积核权重
  • (b):偏置项
  • (k):卷积核大小

卷积操作示意图:

在这里插入图片描述
在这里插入图片描述


2. 激活函数(Activation Function)

激活函数引入非线性因素,常用的是 ReLU 函数:

f ( x ) = max ⁡ ( 0 , x ) f(x) = \max(0, x) f(x)=max(0,x)

ReLU 函数示意图:

在这里插入图片描述


3. 池化层(Pooling Layer)

池化层通过降维保留关键信息,常用的是最大池化(Max Pooling):

y [ i , j ] = max ⁡ m , n ∈ window x [ i + m , j + n ] y[i, j] = \max_{m, n \in \text{window}} x[i+m, j+n] y[i,j]=m,nwindowmaxx[i+m,j+n]

池化操作示意图:

在这里插入图片描述


LeNet-5 的结构

本文实现了经典的 LeNet-5 模型,用于 MNIST 手写数字分类。

LeNet-5 的结构如下:

层类型输入大小卷积核大小输出大小
输入层1 x 28 x 28-1 x 28 x 28
卷积层 11 x 28 x 285 x 56 x 24 x 24
池化层 16 x 24 x 242 x 26 x 12 x 12
卷积层 26 x 12 x 125 x 516 x 8 x 8
池化层 216 x 8 x 82 x 216 x 4 x 4
全连接层 1256-120
全连接层 2120-84
输出层84-10

示意图:
在这里插入图片描述


PyTorch 实现

数据加载

使用 torchvision.datasets 下载 MNIST 数据集,数据经过归一化和转换为张量。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

LeNet-5 模型

LeNet-5 使用两层卷积、两层池化和三层全连接。

import torch
from torch.nn import Module
from torch import nn

class LeNet5(Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

模型训练

模型使用交叉熵损失函数和 Adam 优化器。

import torch.optim as optim
from torch import nn

model = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    for batch_idx, (data, label) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/10 | Batch {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

torch.save(model.state_dict(), 'mnist_lenet5.pth')
print("模型训练完成,已保存至 'mnist_lenet5.pth'")

模型验证

训练完成后,验证模型的准确率。

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data, label in test_loader:
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

accuracy = correct / total
print(f"模型在测试集上的准确率: {accuracy * 100:.2f}%")

结果分析

训练模型的损失值逐渐下降,最终测试集准确率为 99.2%,表明模型在手写数字分类任务上的效果非常好。

EpochBatchLoss
11002.3096
55000.0701
109000.0268

在这里插入图片描述


总结

本文从原理、实现到结果分析详细介绍了卷积神经网络和 LeNet-5 模型。关键点包括:

  1. 卷积、激活函数和池化操作 是 CNN 的核心。
  2. LeNet-5 是 CNN 的经典结构,适用于简单任务。
  3. PyTorch 提供了强大的工具链,帮助快速构建和训练模型。

参考文献

  • LeNet-5 Paper
  • PyTorch Documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2250343.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

A-star算法

算法简介 A*(A-star)算法是一种用于图形搜索和路径规划的启发式搜索算法,它结合了最佳优先搜索(Best-First Search)和Dijkstra算法的思想,能够有效地寻找从起点到目标点的最短路径。A*算法广泛应用于导航、…

【数据集划分】训练集train/验证集val/测试集test是如何划分的?

🚀在跑代码时常常将数据集简单的划分为训练集train和测试集test(二划分),其实更为全面完整的划分应该是划分为训练集train、验证集val、测试集test(三划分)。那么具体如何划分呢?各个部分起着什么作用呢? 如下图所示,…

Gentoo Linux部署LNMP

一、安装nginx 1.gentoo-chxf ~ # emerge -av nginx 提示配置文件需更新 2.gentoo-chxf ~ # etc-update 3.gentoo-chxf ~ # emerge -av nginx 4.查看并启动nginx gentoo-chxf ~ # systemctl status nginx gentoo-chxf ~ # systemctl start nginx gentoo-chxf ~ # syst…

Ubantu系统非root用户安装docker教程

非root用户没有超级权限,根据docker安装教程安装完毕会发现无法拉取镜像,或者每次运行docker都需要加上sudo,输入密码验证。 解决办法如下: 1、创建docker用户组 sudo groupadd docker2、将非root用户(当前用户&am…

python可视化高纬度特征

可视化网络的特征层,假如resnet网络输出的特征维度是(batch_size,512). 如果要可视化测试集的每个图片的512高维度特征分布呢? embeds resnet18(x),embeds是(batch_size,512)高维度特征。如下可视化。 import torch import matplotlib.pyp…

OceanBase 大数据量导入(obloader)

现需要将源数据库(Oracle|MySQL等)一些表的海量数据迁移到目标数据库 OceanBase 中,基于常规 jdbc 驱动编码的方式涉及开发工作,性能效率也要看编码的处理机制。 OceanBase 官方提供了的 OceanBase Migration Service (OMS) 数据…

Mac启动服务慢问题解决,InetAddress.getLocalHost().getHostAddress()慢问题。

项目启动5分钟,很明显有问题。像网上其他的提高jvm参数就不说了,应该不是这个问题,也就快一点。 首先找到自己的电脑名称(用命令行也行,只要能找到自己电脑名称就行,这里直接在共享里看)。 复制…

Ubuntu交叉编译 opencv for QNX

前言 在高通板子上开发一些程序的时候,会用到opencv帮助处理一下图像数据,高通车载板子sa8155和sm8295都有QNX os,需要交叉编译opencv的库,(这个交叉编译真是搞得我太恶心了,所以进行一个记录和分享) 搜了很多资料,有些太过于复杂,有些也存在错误导致最后没有编译成…

.NET 9 AOT的突破 - 支持老旧Win7与XP环境

引言 随着技术的不断进步,微软的.NET 框架在每次迭代中都带来了令人惊喜的新特性。在.NET 9 版本中,一个特别引人注目的亮点是 AOT( Ahead-of-Time)支持,它允许开发人员将应用程序在编译阶段就优化为能够在老旧的 Win…

Mac 环境下类Xshell 的客户端介绍

在 Mac 环境下,类似于 Windows 环境中 Xshell 用于访问 Linux 服务器的工具主要有以下几种: SecureCRT: 官网地址:https://www.vandyke.com/products/securecrt/介绍:支持多种协议,如 SSH1、SSH2、Telnet 等…

Cookie跨域

跨域:跨域名(IP) 跨域的目的是共享Cookie。 session操作http协议,每次既要request,也要response,cookie在创建的时候会产生一个字符串然后随着response返回。 全网站的各个页面都会带着登陆的时候的cookie …

虚拟机CentOS系统通过Docker部署RSSHub并映射到主机

公告 📌更新公告 20241124-该文章已同步更新到作者的个人博客(链接:虚拟机CentOS系统通过Docker部署RSSHub并映射到主机) 一、编辑 YUM 配置文件 1、打开 CentOS 系统中的 YUM 软件仓库配置文件 vim /etc/yum.repos.d/CentOS-Ba…

DreamCamera2相机预览变形的处理

最近遇到一个问题,相机更换了摄像头后,发现人像角度顺时针旋转了90度,待人像角度正常后,发现 预览时图像有挤压变形,最终解决。在此记录 一人像角度的修改 先放示意图 设备预览人像角度如图1所示,顺时针旋…

Taro React小程序开发框架 总结

目录 一、安装 二、目录结构 三、创建一个自定义页面 四、路由 1、API 2、传参 3、获取路由参数 4、设置TabBar 五、组件 六、API Taro非常好用的小程序框架,React开发者无缝衔接上。 一、安装 官方文档:Taro 文档 注意,项目创建…

RPA:电商订单处理自动化

哈喽,大家好,我是若木,最近闲暇时间较多,于是便跟着教程做了一个及RPA,谈到这个,可能很多人并不是很了解,但是实际上,这玩意却遍布文末生活的边边角角。话不多说,我直接上…

通过金蝶云星空实现高效仓储管理

金蝶云星空数据集成到旺店通WMS的技术案例分享 在企业日常运营中,库存管理和物流调度是至关重要的环节。为了实现高效的数据流转和业务协同,我们采用了轻易云数据集成平台,将金蝶云星空的数据无缝对接到旺店通WMS。本次案例聚焦于“调拨入库…

go结构体匿名“继承“方法冲突时继承优先顺序

在 Go 语言中,匿名字段(也称为嵌入字段)可以用来实现继承的效果。当你在一个结构体中匿名嵌入另一个结构体时,嵌入结构体的方法会被提升到外部结构体中。这意味着你可以直接通过外部结构体调用嵌入结构体的方法。 如果多个嵌入结…

丹摩|丹摩智算平台使用教学指南

本指南旨在为新用户提供一个详细的操作步骤和实用的入门指导,帮助大家快速上手丹摩智算平台。 一、平台简介 丹摩智算平台是一款强大的数据分析和计算平台,支持多种编程语言,提供丰富的数据处理和机器学习工具。无论您是数据分析师、开发者…

从网桥到交换机:技术演变与应用场景

交换机和网桥是网络基础设施中不可或缺的设备,它们都用于提升网络性能和连接网络节点。然而,两者在设计目的、功能范围和适用场景上存在诸多不同之处。本文将从功能、差异和相互关系的角度,探讨交换机与网桥在网络中的角色。 交换机的功能与特…

ollama部署bge-m3,并实现与dify平台对接

概述 这几天为了写技术博客,各种组件可谓是装了卸,卸了装,只想复现一些东西,确保你们看到的东西都是可以复现的。 (看在我这么认真的份上,求个关注啊,拜托各位观众老爷了。) 这不,为了实验在windows上docker里运行pytorch,把docker重装了。 dify也得重装: Dify基…