深度学习---卷积神经网络

news2024/9/22 11:29:01

卷积神经网络概述

卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大,使用全连接网络的话,计算的代价较高。另外图像也很难保留原有的特征,导致图像处理的准确率不高。

卷积神经网络(Convolutional Neural Network)是含有卷积层的神经网络。卷积层的作用就是用来自动学习、提取图像的特征。

CNN网络主要有三部分构成:卷积层、池化层和全连接层构成,其中卷积层负责提取图像中的局部特征;池化层用来大幅降低参数量级(降维);全连接层类似人工神经网络的部分,用来输出想要的结果。

图像概述

图像是由像素点组成的,每个像素点的值范围为:[0,255],像素值越大意味着较亮。比如一张 200x200 的图像,则是由 40000 个像素点组成,如果每个像素点都是 0 的话,意味着这是一张全黑的图像。

彩色图一般都是多通道的图像,所谓多通道可以理解为图像由多个不同的图像层叠加而成,例如看到的彩色图像一般都是由 RGB 三个通道组成的,还有一些图像具有 RGBA 四个通道,最后一个通道为透明通道,该值越小,则图像越透明。

卷积层

卷积计算

Padding

Stride

多通道卷积计算

多卷积核卷积计算

特征图大小

池化层

池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度。即:主要对卷积层学习到的特征图进行下采样(SubSampling)处理。池化层主要有两种:最大池化、平均池化。

池化层计算

Stride

Padding

多通道池化计算

案例-图像分类

CIFAR10 数据集

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


# 1. 数据集基本信息
def test01():
    # 加载数据集
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))

    # 数据集数量
    print('训练集数量:', len(train.targets))
    print('测试集数量:', len(valid.targets))

    # 数据集形状
    print("数据集形状:", train[0][0].shape)

    # 数据集类别
    print("数据集类别:", train.class_to_idx)


# 2. 数据加载器
def test02():
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    dataloader = DataLoader(train, batch_size=8, shuffle=True)
    for x, y in dataloader:
        print(x.shape)
        print(y)
        break


if __name__ == '__main__':
    test01()
    test02()

搭建图像分类网络

class ImageClassification(nn.Module):

    def __init__(self):
        super(ImageClassification, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        # 由于最后一个批次可能不够 32,所以需要根据批次数量来 flatten
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))

        return self.out(x)

编写训练函数

使用多分类交叉熵损失函数,Adam 优化器:

def train():
    # 加载 CIFAR10 训练集, 并将其转换为张量
    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)

    # 构建图像分类模型
    model = ImageClassification()
    # 构建损失函数
    criterion = nn.CrossEntropyLoss()
    # 构建优化方法
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # 训练轮数
    epoch = 100

    for epoch_idx in range(epoch):

        # 构建数据加载器
        dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
        # 样本数量
        sam_num = 0
        # 损失总和
        total_loss = 0.0
        # 开始时间
        start = time.time()
        correct = 0

        for x, y in dataloader:
            # 送入模型
            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            correct += (torch.argmax(output, dim=-1) == y).sum()
            total_loss += (loss.item() * len(y))
            sam_num += len(y)

        print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %
              (epoch_idx + 1,
               total_loss / sam_num,
               correct / sam_num,
               time.time() - start))

    # 序列化模型
    torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

def test():
    # 加载 CIFAR10 测试集, 并将其转换为张量
    transgform = Compose([ToTensor()])
    cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)
    # 构建数据加载器
    dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型
    model = ImageClassification()
    model.load_state_dict(torch.load('model/image_classification.bin'))
    model.eval()

    total_correct = 0
    total_samples = 0
    for x, y in dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)

    print('Acc: %.2f' % (total_correct / total_samples))

总结

可以从以下几个方面来调整网络:

  1. 增加卷积核输出通道数
  2. 增加全连接层的参数量
  3. 调整学习率
  4. 调整优化方法
  5. 修改激活函数
  6. 等等...

把学习率由 1e-3 修改为 1e-4,并网络参数量增加如下代码所示:

class ImageClassification(nn.Module):

    def __init__(self):
        super(ImageClassification, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 128, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(128 * 6 * 6, 2048)
        self.linear2 = nn.Linear(2048, 2048)
        self.out = nn.Linear(2048, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        # 由于最后一个批次可能不够 32,所以需要根据批次数量来 flatten
        x = x.reshape(x.size(0), -1)

        x = F.relu(self.linear1(x))
        x = F.dropout(x, p=0.5)

        x = F.relu(self.linear2(x))
        x = F.dropout(x, p=0.5)

        return self.out(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1122359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安全多方计算框架最全合集(持续更新)

安全多方计算框架 本文对现有安全多方计算/学习框架进行了全面、系统的梳理。 目前大部分安全多方计算框架主要基于秘密共享、同态加密、混淆电路以及相关基本模块的组合。通常使用定制的协议来支持特定数量的参与方(一般为两方或三方),导致…

我发现用StarUML来画UML也挺香的

2023年10月22日,周日晚上 我已经决定以后都用StarUML来画UML了,因为这个软件不仅免费,而且太适合画UML图了 我之前主要用plantUML和draw.io来画UML, plantUML虽然是通过文本来生成UML图,但是排版不是不好看&#xff…

“智能文件批量改名工具:轻松管理文件名,一键去除特殊符号“

你是否曾经在面对一堆文件名中包含特殊符号,而感到困扰,不知道如何快速、准确地处理它们?现在,我们为你带来了一款智能文件批量改名工具,它可以轻松地帮助你去除文件名中的特殊符号,让你的文件管理更加规范…

【试题002】C语言有关于sizeof的使用

1.说明&#xff1a;sizeof()是测量数据类型所占用的内存字节数&#xff0c;字符串常量在存储时除了要存储有效字节外&#xff0c;还要存储一个字符串结束志‘\0’。 2.代码举栗子&#xff1a; #include <stdio.h> int main() {char str[] "book";printf(&qu…

网络协议--Traceroute程序

8.1 引言 由Van Jacobson编写的Traceroute程序是一个能更深入探索TCP/IP协议的方便可用的工具。尽管不能保证从源端发往目的端的两份连续的IP数据报具有相同的路由&#xff0c;但是大多数情况下是这样的。Traceroute程序可以让我们看到IP数据报从一台主机传到另一台主机所经过…

开源LLEMMA发布:超越未公开的顶尖模型,可直接应用于工具和定理证明

深度学习自然语言处理 原创作者&#xff1a;Winnie 今天向大家介绍一个新的开源大语言模型——LLEMMA&#xff0c;这是一个专为数学研究而设计的前沿语言模型。 LLEMMA解数学题的一个示例 LLEMMA的诞生源于在Proof-Pile-2数据集上对Code Llama模型的深度训练。这个数据集是一个…

Java 8 新特性 Ⅱ

方法引用 举例: Integer :: compare 理解: 可以看作是基于lambda表达式的进一步简化 当需要提供一个函数式接口的实例时, 可以使用lambda表达式提供实例 当满足一定条件下, 可以使用方法引用or构造器引用替换lambda表达式 实质: 方法引用作为函数式接口的实例 (注: 需要熟悉…

【AOA-VMD-LSTM分类故障诊断】基于阿基米德算法AOA优化变分模态分解VMD的长短期记忆网络LSTM分类算法(Matlab代码)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

人大金仓获评“2023年度软件和信息技术服务名牌企业”

近日&#xff0c;中国电子信息行业联合会在2023世界数字经济大会暨第十三届智慧城市与智能经济博览会上&#xff0c;发布了“2023年度软件和信息技术服务名牌企业”&#xff0c;凭借在企业规模、技术创新、市场影响力等方面的突出表现&#xff0c;人大金仓成功入选。 此次评选在…

计组03:20min导图复习 中央处理器

&#x1f433;前言 图源&#xff1a;文心一言 考研笔记整理&#xff0c;纯复习向&#xff0c;思维导图基本就是全部内容了&#xff0c;不会涉及较深的知识点~~&#x1f95d;&#x1f95d; 第1版&#xff1a;查资料、画思维导图~&#x1f9e9;&#x1f9e9; 编辑&#xff1a;…

【数值分析】2 - 插值法

文章目录 一、引言1.1 插值法引入1.2 常用插值法1.3 插值法定义 二、插值法研究的问题2.1 插值多项式存在的唯一性2.2 如何构造n次多项式2.2.1 待定系数法2.2.2 拉格朗日插值法2.2.2.1 拉格朗日多项式2.2.2.2 拉格朗日插值余项2.2.2.3 例题2.2.2.4 拉格朗日插值法的问题 2.2.3 …

【大疆智图】大疆智图(DJI Terra 3.0.0)安装及使用教程

大疆智图是一款以二维正射影像与三维模型重建为主的软件,同时提供二维多光谱重建、激光雷达点云处理、精细化巡检等功能。它能够将无人机采集的数据可视化,实时生成高精度、高质量三维模型,满足事故现场、工程监测、电力巡线等场景的展示与精确测量需求。 文章目录 1. 安装D…

shell学习脚本04(小滴课堂)

他就可以直接读出来了。不需要在sh后面加参数。 可以用-s隐藏内容&#xff1a; 可以用-t进行指定几秒后显示。 -n限制内容长度。 输入到长度为5自动打印。 我们把-s放到-p后面的话&#xff1a; 这样会出错。 如果最后加5m会一直闪烁。 大家可以按照需求自行使用。

42907-2023 硅锭、硅块和硅片中非平衡载流子复合寿命的测试 非接触涡流感应法

1 范围 本文件描述了用非接触式涡流感应法测试太阳能电池用单晶硅锭、硅块和硅片中非平衡载流子复 合寿命的方法。 本文件适用于非平衡载流子复合寿命在0.1μs&#xff5e;10000 μs、电阻率在0.1 Ω cm&#xff5e;10000 Ω cm 的硅锭、硅块和硅片的测试。其中瞬态光电导衰…

Spring定时任务@Scheduled

在 Spring 框架中&#xff0c;可以使用定时任务来执行周期性或延迟执行的任务。Spring 提供了多种方式来配置和管理定时任务。有Java自带的java.util.Timer类&#xff0c;也有强大的调度器Quartz&#xff0c;还有SpringBoot自带的Scheduled。 在实际应用中&#xff0c;如果没有…

实际项目中如何进行问题排查

Linux自带 文本操作 文本查找 - grep文本分析 - awk文本处理 - sed文件操作 文件监听 - tail文件查找 - find网络和进程 网络接口 - ifconfig防火墙 - iptables -L路由表 - route -nnetstat其他常用 进程 ps -ef | grep java分区大小 df -h内存 free -m硬盘大小 fdisk -l | gr…

STM32F4_USB读卡器(USB_Slave)/USB U盘(Host)

前言 STM32F4芯片自带了USB OTG FS&#xff08;FS&#xff0c;即全速&#xff0c;12Mbps&#xff09;和USB OTG HS&#xff0c;支持USB Host和USB Device。 1. USB简介 USB&#xff0c;是英文Universal Serial BUS&#xff08;通用串行总线&#xff09;的缩写&#xff0c;是一…

硬盘无法分区的原因以及3种解决方法!

硬盘无法分区的原因 无论是新买的硬盘还是用了很久的硬盘&#xff0c;在分区过程中都可能会遇到硬盘无法分区的问题。在这里我们总结了以下几点原因&#xff1a; 主板生产商为了防止病毒侵入引导区文件在主板进行了设置&#xff0c;导致硬盘无法进行分区。 新买的硬盘没有…

42910-2023 无机胶粘剂高温压缩剪切强度试验方法

1 范围 本文件描述了在高温条件下测定无机胶粘剂压缩剪切强度的试验方法。 本文件适用于300℃~1000℃温度范围内&#xff0c;耐热陶瓷、复合材料及其他非金属材料之间搭接压缩 剪切强度的测定。 2 规范性引用文件 下列文件中的内容通过文中的规范性引用而构成本文件必不可…