【深度学习入门篇 ⑧】关于卷积神经网络

news2024/11/14 3:57:56

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


关于卷积神经网络,你还有哪些不知道的知识点呢,之前我们介绍了大部分,今天再来补充一下~

卷积神经网络基础

什么是卷积

Convolution,输入信息与核函数(滤波器)的乘积

  • 一维信号的时间卷积:输入x,核函数w,输出是一个连续时间段t的加权平均结果。
  • 二维图像的空间卷积:输入图像I,卷积核K,输出图像O。

单个二维图片卷积 :输入为单通道图像,输出为单通道图像。

图像的数据存储是多通道的二维矩阵:

灰度图(Gray)只有一个通道(一层),RGB彩色图就是三个通道(Red,Green,Blue),而RGBA彩色图就是四个通道(Red,Green,Blue,Alpha)。

如何表达每一个网络层中高维的图像数据?

特征图包含:通道,宽度,高度,其中输入特征图Ci,输出特征图C0,输出特征图的每一个通道,由输入图的所有通道和相同数量的卷积核先一一对应各自进行卷积计算,然后求和

 

卷积相关操作与参数

填充

padding :给卷积前的输入图像边界添加额外的行,列

  • 控 制 卷 积 后 图 像分 辨 率 , 方便计算特征图尺寸的变化
  • 弥 补 边 界 信 息 “丢 失 ” 

步长

步长(stride):卷积核在图像上移动的步子

卷积的核心思想 

为什么要进行局部连接?

  • 局部连接可以更好地利用图像中的结构信息,空间距离越相近的像素其相互影响越大

权重共享:保证不变性,图像从一个局部区域学习到的信息应用到其他区域 ,减少参数,降低学习难度。

ANN与CNN比较

传统神经网络为有监督的机器学习,输入为特征;卷积神经网络为无监督特征学习,输入为最原始的图像。


案例-图像分类

CIFAR10 数据集

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。

 PyTorch 中的 torchvision.datasets 计算机视觉模块封装了 CIFAR10 数据集:

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


def func1():

    # 加载数据集
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))

    print('训练集数量:', len(train.targets))
    print('测试集数量:', len(valid.targets))

    print("数据集形状:", train[0][0].shape)

    print("数据集类别:", train.class_to_idx)


# 数据加载器
def func2():

    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    dataloader = DataLoader(train, batch_size=8, shuffle=True)
    for x, y in dataloader:
        print(x.shape)
        print(y)
        break


if __name__ == '__main__':
    func1()
    func2()

我们要搭建的网络结构:

我们在每个卷积计算之后应用 relu 激活函数来给网络增加非线性因素。

网络代码实现:

class ImageClassification(nn.Module):


    def __init__(self):

        super(ImageClassification, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = x.reshape(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))

        return self.out(x)

 编写训练函数

训练时,使用多分类交叉熵损失函数,Adam 优化器:

def train():

    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)


    model = ImageClassification()

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # 训练轮数
    epoch = 100

    for epoch_idx in range(epoch):

        # 构建数据加载器
        dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
        # 样本数量
        sam_num = 0
        # 损失总和
        total_loss = 0.0
        # 开始时间
        start = time.time()
        correct = 0

        for x, y in dataloader:
            # 送入模型
            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            correct += (torch.argmax(output, dim=-1) == y).sum()
            total_loss += (loss.item() * len(y))
            sam_num += len(y)

        print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %
              (epoch_idx + 1,
               total_loss / sam_num,
               correct / sam_num,
               time.time() - start))


    torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

我们加载训练好的模型,对测试集中的 1 万条样本进行预测,查看模型在测试集上的准确率

def test():


    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)
    # 构建数据加载器
    dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型
    model = ImageClassification()
    model.load_state_dict(torch.load('model/image_classification.bin'))
    model.eval()


    total_correct = 0
    total_samples = 0
    for x, y in  dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)

    print('Acc: %.2f' % (total_correct / total_samples))

输出:

'Acc: 0.61

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1932342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

速部署 HBase 测试环境

快速部署 HBase 测试环境 第一步:下载软件,在HBase官网下载最新版, 找到 bin,点击下载,比如我这里下载的是 hbase-2.5.6-bin.tar.gz 第二步:解压软件 $ tar -zxvf hbase-2.5.6-bin.tar.gz $ cd hbase-2.…

完美解决ImportError: cannot import name ‘PILLOW_VERSION‘的正确解决方法,亲测有效!!!

完美解决ImportError: cannot import name PILLOW_VERSION’的正确解决方法,亲测有效!!! 亲测有效 完美解决ImportError: cannot import name PILLOW_VERSION的正确解决方法,亲测有效!!&#xf…

Java---多态

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 欢迎大家访问~ 创作不易,大佬们点赞鼓励下吧~ 前言 前面博客了解了Java语法中继承的相关知识&…

泛微Ecology8明细表对主表赋值

文章目录 [toc]1.需求及效果1.1 需求1.2 效果2.思路与实现3.结语 1.需求及效果 1.1 需求 在明细表中的项目经理,可以将值赋值给主表中的项目经理来作为审批人员 1.2 效果 在申请人保存或者提交后将明细表中的人名赋值给主表中对应的值2.思路与实现 在通过js测…

MySQL数据库慢查询日志、SQL分析、数据库诊断

1 数据库调优维度 业务需求:勇敢地对不合理的需求说不系统架构:做架构设计的时候,应充分考虑业务的实际情况,考虑好数据库的各种选择(读写分离?高可用?实例个数?分库分表?用什么数据库?)SQL及索引:根据需求编写良…

鸿蒙语言基础类库:【@system.notification (通知消息)】

通知消息 说明: 从API Version 7 开始,该接口不再维护,推荐使用新接口[ohos.notification]。本模块首批接口从API version 3开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import notification fro…

生信学院|07月19日《在线产品设计工具》

课程主题:在线产品设计工具 课程时间:2024年07月19日 14:00-14:30 主讲人:武旭 生信科技 售后服务工程师 基于云的设计平台高效的在线设计工具与SOLIDWORKS对比Q&A 安装腾讯会议客户端或APP,微信扫描海报中的二维码报名哦…

【光伏发电功率预测】方法综述学习笔记1

文章目录 研究背景为什么要做光伏发电功率预测?光伏功率预测难点影响光伏发电的因素光伏发电功率预测分类光伏发电功率预测方法预测评价指标总结 研究背景 近十年,化石能源消耗不断增加,环境污染日趋严重,已经成为国际社会普遍关…

连接池应用

一、什么是连接池: 当应用程序需要执行数据库操作时,它会从连接池中请求一个可用的连接。如果连接池中有空闲的连接,那么其中一个连接会被分配给请求者。一旦数据库操作完成,连接不会被关闭,而是被归还到连接池中&…

Seata的TCC模式与XA模式实战使用

文章目录 SeataXA模式整体机制微服务整合SeataXA SeataTCC模式什么是TCC以用户下单为例Seata TCC 模式Seata TCC模式接口改造TCC如何控制异常空回滚幂等悬挂 微服务整合SeataTCC 比较 SeataXA模式 XA协议最主要的作用是就是定义了RM-TM的交互接口,除此之外&#xf…

对LinkedList ,单链表和双链表的理解

一.ArrayList的缺陷 二.链表 三.链表部分相关oj面试题 四.LinkedList的模拟实现 五.LinkedList的使用 六.ArrayList和LinkedList的区别 一.ArrayList的缺陷: 1. ArrayList底层使用 数组 来存储元素,如果不熟悉可以来再看看: ArrayList与顺序表-CSDN…

zephyr BLE创建自定义服务

目录 LBS服务介绍实现过程 以创建LBS服务为例,在蓝牙标准里面没有这个服务,但是nordic有定制这个服务。 LBS服务介绍 实现过程 定义 GATT 服务及其特性的 128 位 UUID。包括服务UUID,特征的UUID。 #define BT_UUID_LBS_VAL BT_UUID_128_EN…

【BUG】已解决:ValueError: Expected 2D array, got 1D array instead

已解决:ValueError: Expected 2D array, got 1D array instead 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉…

“论软件维护方法及其应用”精选范文,软考高级论文,系统架构设计师论文

论文真题 软件维护是指在软件交付使用后,直至软件被淘汰的整个时间范围内,为了改正错误或满足 新的需求而修改软件的活动。在软件系统运行过程中,软件需要维护的原因是多种多样的, 根据维护的原因不同,可以将软件维护…

【Linux】线程——线程互斥的概念、锁的概念、互斥锁的使用、死锁、可重入和线程安全、线程同步、条件变量的概念和使用

文章目录 Linux线程4. 线程互斥4.1 线程互斥的概念4.2 锁的概念4.2.1 互斥锁的概念4.2.2 互斥锁的使用4.2.3 死锁4.2.4 可重入和线程安全 5. 线程同步5.1 条件变量的概念5.2 条件变量的使用 Linux线程 4. 线程互斥 我们之前使用了线程函数实现了多线程的简单计算模拟器。 可以…

3D问界—在MAYA中使用Python脚本进行批量轴居中

问题提出:MAYA中如何使用Python脚本 今天不是一篇纯理论,主要讲一下MAYA中如何使用Python脚本,并解决一个实际问题,文章会放上我自己的代码,若感兴趣欢迎尝试,当然,若有问题可以见文章末尾渠道&…

防火墙--带宽管理

目录 核心思想 带宽限制 带宽保证 连接数的限制 如何实现 接口带宽 队列调度 配置位置 在接口处配置 带宽策略配置位置 带宽通道 配置地方 接口带宽、带宽策略和带宽通道联系 配置顺序 带块通道在那里配置 选项解释 引用方式 策略独占 策略共享 重标记DSCP优先…

C# 中IEnumerable与IQuerable的区别

目的 详细理清IEnumerator、IEnumerable、IQuerable三个接口之间的联系与区别 继承关系:IEnumerator->IEnumerable->IQuerable IEnumerator:枚举器 包含了枚举器含有的方法,谁实现了IEnuemerator接口中的方法,就可以自定…

【坑】微信小程序开发wx.uploadFile和wx.request的返回值格式不同

微信小程序 使用wx.request,返回值是json,如下 {code:200,msg:"更新用户基本信息成功",data:[]} 因此可以直接使用如 res.data.code获取到返回值中的code字段 但是,上传图片需要使用wx.uploadFile,返回的结果如下 …

【知识图谱】【红楼梦】

参考链接 安装、使用教程(知乎):https://zhuanlan.zhihu.com/p/634006024Git :https://github.com/chizhu/KGQA_HLM 注:原项目为 【 重庆邮电大学,2018 林智敏 的毕业设计 】。【 感谢大佬的分享 】。 jav…