12.26 学习卷积神经网路(CNN)

news2025/2/13 11:35:26

完全是基于下面这个博客来进行学习的,感谢!

​​【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客

 基于深度神经网络DNN实现的手写数字识别,将灰度图像转换后的二维数组展平到一维,将一维的784个特征作为模型输入。在“展平”的过程中必然会失去一些图像的形状结构特征,因此基于DNN的实现方式并不能很好的利用图像的二维结构特征,而卷积神经网络CNN对于处理图像的位置信息具有一定的优势。因此卷积神经网络经常被用于图像识别/处理领域。

1、卷积层

内参数(卷积核本身)  

CNN中的卷积层和DNN中的全连接层是平级关系,在DNN中,我们训练的内参数是全连接层的权重w和偏置b,CNN也类似,CNN训练的是卷积核,也就相当于包含了权重和偏置两个内部参数。

 输入一个多维数据(上图为二维),与卷积核进行运算,即输入中与卷积核形状相同的部分,分别与卷积核进行逐个元素相乘再相加。例如计算结果中坐上角的15是根据如下过程计算得到的:

 

逐个元素相乘再相加,即:

1 * 2 + 2 * 0 + 3* 1 + 0 * 0 + 1 * 1 + 2 * 2 + 3 * 1 + 0 * 0 + 1 * 2 = 15 

外参数(填充和步幅)

   填充(padding)

   显然,只要卷积核的大小>1*1,必然会导致图像越卷越小,为了防止输入经过多个卷积层后变得过小,可以在发生卷积层之前,先向输入图像的外围填入固定的数据(比如0),这个步骤称之为填充,如下图:

 在使用Pytorch搭建卷积层的时候,需要在对应的接口中添加这个padding参数,向上图中这种情况,相当于在3*3的卷积核外围添加了“一圈”,则padding = 1,卷积层的接口中就要这样写:

nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, paddding=1)

      参数in_channels和out_channels是对应于这个卷积层输入和输出的通道数参数,这里我们先放一放。

步幅(stride)

   步幅指的是使用卷积核的位置间隔,即输入中参与运算的那个范围每次移动的距离。前面几个示意图中的步幅均为1,即每次移动一格,如果设置stride=2,kernel_size=2,则效果如下:

     此时需要在卷积层接口中添加参数stride=2。

输入与尺寸的关系

 综上所述,结合外参数(步幅、填充)和内参数(卷积核),可以看出如下规律:

卷积核越大,输出越小。

步幅越大,输出越小。

填充越大,输出越大。

     用公式表示定量关系:

      如果输入和卷积核均为方阵,设输入尺寸为W*W,输出尺寸为N*N,卷积核尺寸为F*F,填充的圈数为P,步幅为S,则有关系:

N=\frac{W+2P-F}{S}+1

    这个关系大家要重点掌握,也可以自己推导一下,并不复杂。如果输入和卷积核不为方阵,设输入尺寸是H*W,输出尺寸是OH*OW,卷积核尺寸为FH*FW,填充为P,步幅为S,则输出尺寸OH*OW的计算公式是:

OH=\frac{H+2P-FH}{S}+1

OW=\frac{W+2P-FW}{S}+1

2、多通道问题

多通道输入

对于手写数字识别这种灰度图像,可以视为仅有(高*长)二维的输入。然而,对于彩色图像,每一个像素点都相当于是RGB的三个值的组合,因此对于彩色的图像输入,除了高*长两个维度外,还有第三个维度——通道,即红、绿、蓝三个通道,也可以视为3个单通道的二维图像的混合叠加。

当输入数据仅为二维时,卷积层的权重往往被称作卷积核(Kernel);

当输入数据为三维或更高时,卷积层的权重往往被称作滤波器(Filter)。

     对于多通道输入,输入数据和滤波器的通道数必须保持一致。这样会导致输出结果降维成二维,如下图:

对形状进行一下抽象,则输入数据C*H*W和滤波器C*FH*FW都是长方体,结果是一个长方形1*OH*OW,注意C,H,W是固定的顺序,通道数要写在最前。

多通道输出

 如果要实现多通道输出,那么就需要多个滤波器,让三维输入与多个滤波器进行卷积,就可以实现多通道输出,输出的通道数FN就是滤波器的个数FN,如下图:

 和单通道一样,卷积运算后也有偏置,如果进一步追加偏置,则结果如下:每个通道都有一个单独的偏置

3、池化层

池化,也叫汇聚(Pooling)。池化层通常位于卷积层之后(有时也可以不设置池化层),其作用仅仅是在一定范围内提取特征值,所以并不存在要学习的内部参数。池化仅仅对图像的高H和宽W进行特征提取,并不改变通道数C

一般有平均汇聚和最大值汇聚两种。

平均汇聚

如上图,池化的窗口大小为2*2,对应的步幅为2,因此对于上图这种情况,对应的Pytorch接口如下:

nn.AvgPool2d(kernel_size=2, stride=2) 

最大值汇聚

对应的Pytorch接口如下:

nn.MaxPool2d(kernel_size=2, stride=2) 

4、手写数字识别

详细内容见,我自己就在jupyter上写写代码【深度学习基础】详解Pytorch搭建CNN卷积神经网络LeNet-5实现手写数字识别_pytorch cnn-CSDN博客

任务描述

输入就相当于一个单通道的图像,是二维的。我们在实现的时候,要将每个样本图像转换为28*28的张量,作为输入

 因此对于整个手写数字识别的任务,模型的输入是一副图像,输出则是一个对应的识别出的数字(0-9之间的整数)。这里注意,在进行模型训练时,PyTorch会在整个过程中自动将输出转换为One-Hot编码,因此我们在训练时不需要手动将输出转换为One-Hot编码,但进行模型评估测试时,由于需要比对预测输出(One-hot)和真实输出(0-9的数字),要进行一次转化。

    对于单个样本,每个图像都是一个二维灰度图像,像素为28*28.二维灰度图像的通道数为1,因此可以将每个样本图像转换为28*28的张量,作为输入。相当于是下图这样的形式:

    具体怎么转换,需要用到torchvision库中的trabsforms进行图像转换,将数据集转换为张量的形式,并调整数据集的统计分布(转换为标准正态分布更利于训练)。

网络结构(LeNet-5)

LeNet-5起源于1998年,在手写数字识别上非常成功。其结构如下:

再列一个表格,具体结构如下:

注:输出层的激活函数目前已经被Softmax取代。

至于这些尺寸关系,我举个两例子吧:

   以第一层C1的输入和输出为例。输入尺寸W是28*28,卷积核F尺寸为5*5,步幅S为1,填充P为2,那么输出N的28*28怎么来的呢?按照公式如下:

N=\frac{W+2P-F}{S}+1=\frac{28+2*2-5}{1}+1=28

   我们也可以观察到第一层的卷积核个数为6,则输出的通道数也为6。

  再看一下第一个池化层S2,输入尺寸是28*28,卷积核F大小为2*2(此处的“卷积核”实际上指的是采样范围),步幅S=2,填充P为0,则输出的14*14是这么算出来的:

N=\frac{W+2P-F}{S}+1=\frac{28+2*0-2}{2}+1=14

import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os
 
# 将图像转换为张量形式
data_transform = transforms.Compose([
    transforms.ToTensor()
])
 
# 加载训练数据集
train_dataset = datasets.MNIST(
    root='D:\\Jupyter\\dataset\\minst',  # 下载路径
    train=True,   # 是训练集
    download=True,   # 如果该路径没有该数据集,则进行下载
    transform=data_transform   # 数据集转换参数
)
 
# 批次加载器
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
 
# 加载测试数据集
test_dataset = datasets.MNIST(
    root='D:\\Jupyter\\dataset\\minst',  # 下载路径
    train=False,   # 是训练集
    download=True,   # 如果该路径没有该数据集,则进行下载
    transform=data_transform   # 数据集转换参数
)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)
 
# 判断是否有gpu
device = "cuda" if torch.cuda.is_available() else "cpu"
 
# 调用net,将模型数据转移到gpu
model = MyLeNet5().to(device)
 
# 选择损失函数
loss_fn = nn.CrossEntropyLoss()    # 交叉熵损失函数,自带Softmax激活函数
 
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
 
# 学习率每隔10轮次, 变为原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
 
 
# 定于训练函数
def train(dataloader, model, loss_fn, optimizer):
    loss, current, n = 0.0, 0.0, 0
    for batch, (X, y) in enumerate(dataloader):
        # 前向传播
        X, y = X.to(device), y.to(device)
        output = model(X)
        cur_loss = loss_fn(output, y)
        _, pred = torch.max(output, dim=1)
        # 计算当前轮次时,训练集的精确度
        cur_acc = torch.sum(y == pred)/output.shape[0]
 
        # 反向传播
        optimizer.zero_grad()
        cur_loss.backward()
        optimizer.step()
 
        loss += cur_loss.item()
        current += cur_acc.item()
        n = n + 1
    print("train_loss: ", str(loss/n))
    print("train_acc: ", str(current/n))
 
 
def test(dataloader, model, loss_fn):
    model.eval()
    loss, current, n = 0.0, 0.0, 0
    # 该局部关闭梯度计算功能,提高运算效率
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            # 前向传播
            X, y = X.to(device), y.to(device)
            output = model(X)
            cur_loss = loss_fn(output, y)
            _, pred = torch.max(output, dim=1)
            # 计算当前轮次时,训练集的精确度
            cur_acc = torch.sum(y == pred) / output.shape[0]
            loss += cur_loss.item()
            current += cur_acc.item()
            n = n + 1
        print("test_loss: ", str(loss / n))
        print("test_acc: ", str(current / n))
        return current/n    # 返回精确度
 
 
# 开始训练
epoch = 50
max_acc = 0
for t in range(epoch):
    print(f"epoch{t+1}\n---------------")
    train(train_dataloader, model, loss_fn, optimizer)
    a = test(test_dataloader, model, loss_fn)
    # 保存最好的模型参数
    if a > max_acc:
        folder = 'save_model'
        if not os.path.exists(folder):
            os.mkdir(folder)
        max_acc = a
        print("current best model acc = ", a)
        torch.save(model.state_dict(), 'save_model/best_model.pth')
print("Done!")
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2266186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【团标】《信息工程造价政务信息化项目造价评估方法》(TCQAE11021-2023)-费用标准解读系列33

《信息工程造价政务信息化项目造价评估方法》(TCQAE11021-2023)是中国电子质量管理协会2023年发布,2023年12月16日开始实施的标准(了解更多可直接关注我们咨询)。该标准适用于政务信息化项目的造价评估,政务…

mybatisplu设置自动填充

/*** mybatisplus的自动化填充*/public class JboltMetaObjectHandler implements MetaObjectHandler {Overridepublic void insertFill(MetaObject metaObject) {LocalDateTime now LocalDateTime.now(ZoneId.of("Asia/Shanghai"));this.strictInsertFill(metaObje…

音视频入门基础:MPEG2-TS专题(23)——通过FFprobe显示TS流每个packet的信息

一、引言 通过FFprobe命令: ffprobe -of json -show_packets XXX.ts 可以显示TS流/TS文件每个packet(也称为数据包或多媒体包)的信息: 对于TS流,上述的“packet”(数据包或多媒体包)是指&…

Linux电源管理——CPU Hotplug 流程

目录 一、相关概念 二、基本原理 三、代码分析 1、CPU_ON 2、CPU_OFF References Linux Version:linux-5.4.239 一、相关概念 在单核操作系统中,操作系统只需管理一个CPU,当系统有任务需要执行时,所有的任务会在该CPU的就绪…

探索数据的艺术:R语言与Origin的完美结合

探索数据的艺术:R语言与Origin的完美结合 R语言统计分析与可视化从入门到精通内容简介获取方式 Origin绘图深度解析:科研数据的可视化艺术内容简介获取方式 R语言统计分析与可视化从入门到精通 内容简介 本书循序渐进、深入讲解了R语言数据统计分析与应…

python基础训练之元组的基本操作

主页包含元组基础知识点 【练习要求】 针对于元组的知识点进行常用的创建、定义、查询元素、查看元组长度等操作。效果实现如下 (注:特别要注意一下切片的用法) #创建元组的两种方法 T1 () T2 tuple() #定义一个元组并存储数据张三, 李四, 王五 T3 (张三, 李四…

选煤厂可视化技术助力智能化运营

通过图扑 HT 可视化搭建智慧选煤厂管理平台,优化了选煤生产流程,提高了资源利用率和安全性,助力企业实现智能化运营和可持续发展目标。

C语言基础:指针(数组指针与指针数组)

数组指针与指针数组 数组指针 概念:数组指针是指向数组的指针,本质上还是指针 特点: 先有数组,后有指针 它指向的是一个完整的数组 一维数组指针: 语法: 数据类型 (*指针变量名)[行容量][列容量]; 案…

接口测试Day03-postman断言关联

postman常用断言 注意:不需要手敲,点击自动生成 断言响应状态码 Status code:Code is 200 //断言响应状态码为 200 pm.test("Status code is 200", function () {pm.response.to.have.status(200); });pm: postman的实例 test() …

01- 三自由度串联机械臂位置分析

三自由度串联机械臂如下图所示(d180mm,L1100mm,L280mm),利用改进DH法建模,坐标系如下所示: 利用改进DH法建模,该机器人的DH参数表如下所示: 对该机械臂进行位置分析&…

lxml 解析xml\html

from lxml import etree# XML文档示例 xml_doc """ <root><book><title>Python编程指南</title><author>张三</author></book><book><title>Python高级编程</title><author>李四</autho…

用Python写炸金花游戏

文章目录 **代码分解与讲解**1. **扑克牌的生成与洗牌**2. **给玩家发牌**3. **打印玩家的手牌**4. **定义牌的优先级**5. **判断牌型**6. **确定牌型优先级**7. **比较两手牌的大小**8. **计算每个玩家的牌型并找出赢家**9. **打印结果** 完整代码 以下游戏规则&#xff1a; 那…

基于 SpringBoot微信小程序的医院预约挂号系统

摘 要 时代在飞速进步&#xff0c;每个行业都在努力发展现在先进技术&#xff0c;通过这些先进的技术来提高自己的水平和优势&#xff0c;医院预约挂号系统当然不能排除在外。医院预约挂号系统是在实际应用和软件工程的开发原理之上&#xff0c;运用微信开发者、java语言以及…

高仿CSDN编辑器,前端博客模板

高仿CSDN编辑器纯前端模板&#xff0c;使用的js、html、vue、axios等技术&#xff0c;网络请求库已进行封装&#xff0c;可以按需调整界面,需要源码联系(4k左右)。 1.支持代码高亮 2.支持目录点击定位 3.支持文件上传、图片上传&#xff08;需要自己写后端接口&#xff09; 4.M…

【C++11】类型分类、引用折叠、完美转发

目录 一、类型分类 二、引用折叠 三、完美转发 一、类型分类 C11以后&#xff0c;进一步对类型进行了划分&#xff0c;右值被划分纯右值(pure value&#xff0c;简称prvalue)和将亡值 (expiring value&#xff0c;简称xvalue)。 纯右值是指那些字面值常量或求值结果相当于…

在线oj项目 Ubuntu安装vue/cil(vue脚手架)

参考:https://blog.csdn.net/weixin_66062303/article/details/129046198 笔记 参考:https://blog.csdn.net/m0_74352571/article/details/144076227 https://cli.vuejs.org/zh/guide/installation.html 确保nodejs已经安装 npm换源淘宝镜像&#xff08;可以不操作或者使用魔…

Python字符串及正则表达式(十一):正则表达式、使用re模块实现正则表达式操作

前言&#xff1a;在 Python 编程的广阔天地中&#xff0c;字符串处理无疑是一项基础而关键的技能。正则表达式&#xff0c;作为处理字符串的强大工具&#xff0c;以其灵活的模式匹配能力&#xff0c;在文本搜索、数据清洗、格式验证等领域发挥着不可替代的作用。本系列博客已经…

项目37:简易个人健身记录器 --- 《跟着小王学Python·新手》

项目37&#xff1a;简易个人健身记录器 — 《跟着小王学Python新手》 《跟着小王学Python》 是一套精心设计的Python学习教程&#xff0c;适合各个层次的学习者。本教程从基础语法入手&#xff0c;逐步深入到高级应用&#xff0c;以实例驱动的方式&#xff0c;帮助学习者逐步掌…

华为:数字化转型只有“起点”,没有“终点”

上个月&#xff0c;我收到了一位朋友的私信&#xff0c;他询问我是否有关于华为数字化转型的资料。幸运的是&#xff0c;我手头正好收藏了一些&#xff0c;于是我便分享给他。 然后在昨天&#xff0c;他又再次联系我&#xff0c;并感慨&#xff1a;“如果当初我在进行企业数字…

count(1)、count(_)与count(列名)的区别?

大家好&#xff0c;我是锋哥。今天分享关于【count(1)、count(_)与count(列名)的区别&#xff1f;】面试题。希望对大家有帮助&#xff1b; count(1)、count(_)与count(列名)的区别&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在 SQL 中&#xff0c…