PyTorch实现卷积神经网络CNN

news2024/11/28 18:45:15

一、卷积神经网络CNN

二、代码实现(PyTorch)

1. 导入依赖库

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
  • nn:包含了torch已经准备好的层,激活函数、全连接层等

  • optim:提供了神经网络的一系列优化算法,如 SGD、Adam 等

  • datasets:提供常用的数据集,如 MNIST(本次使用)、CIFAR10/100、ImageNet、COCO 等

  • DataLoder:装载上面提到的数据集

2. 准备数据集

        这里使用MNIST数据集,它是一个大型手写数字数据库(包含0~9十个数字),原始的这两个数据集由128×128像素的黑白图像组成。LeCun等人将其进行归一化和尺寸调整后得到的是28×28的灰度图像。

        MNIST数据集总共包含两个子数据集:一个训练数据集(train_dataset)和一个测试数据集(test_dataset)。它们分别包含了60K和10K的28×28的灰度图像。代码如下:

# 训练集
train_dataset = datasets.MNIST(root='./',
                               train=True,
                               transform=transforms.ToTensor(),  # 数据转换为张量格式
                               download=True)
# 测试集
test_dataset = datasets.MNIST(root='./',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

batch_size = 100  # 批次大小
# 装载训练集
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,  # 每次加载多少条数据
                          shuffle=True)  # 生成数据前打乱数据 
# 装载测试集
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)

         这里值得注意的是,datasets.s=MNIST() 的参数 download 表示是否下载到参数 root 下的目录。但是实际使用过程中,从 https://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 下载会出现 403 forbidden 的报错信息。这个不必担心,torch 还会选择其他可用下载链接继续下载。 下载好的数据集应该有如下几个:

或者

3. 构建网络模型

        首先应该清楚,MNIST给到的原始训练集的图像可以表示为(batch_size, 1, 28, 28),其中 batch_size 代表一共加载了多少条数据,这里我之前设置了100;1代表这个训练集的图片是灰度图;两个28则为灰度图的长和宽。

        接下来就可以设计卷积层和池化层。

        设计卷积层时,应该注意第一层的卷积核数量(特征图数量)一般从较小的数值开始,我这里设置了32。因为灰度图的特征还算明显,因此卷积核可以适当减小,缓慢增加感受野,以此提高效率,因此设置为5×5。步长一般设置为1。至于填充几圈0,则可通过图像大小、卷积核大小、步长等推算得知。

        设计池化层时,首先确定池化法,这里选择最大池化法。选择最常用的2×2大小的池化核,它能够将特征图的宽和高减小一半。

        以下是每一层的详细设计思路:

  1. 卷积层1(conv1):先创建一个二维卷积层(Conv2d),然后确定激活函数(ReLU)对卷积层输出的每个值进行非线性变换,最后利用最大池化法(MaxPool)减小特征图尺寸防止过拟合。
  2. 卷积层2(conv2):由卷积层1的输出通道数确定卷积层2的输入通道数,其他不变。
  3. 全连接层1(fc1):使用 Dropout 来控制全连接层的过拟合问题,每次有50%的神经元不使用(只有训练状态下 Dropout 才起作用,测试状态下还是全部神经元工作)。在前向传播时需要注意,应该把卷积层的特征图维数修改为2维。
  4. 全连接层2(fc2):最后将1000个特征图输出为10个数字(0~9)的概率值。这里Softmax不加也行,因为后续在使用交叉熵代价函数(CrossEntropyLoss)时,因为它内部已经包括 Softmax 操作。
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),  # Conv2d(输入通道数(灰度图),输出通道数(生成多少特征图),卷积核大小(5×5),步长,0填充(填充2圈))
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # MaxPool2d(池化核大小2×2,步长为2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  
        )
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1000),  # 将特征压缩为1000维的特征向量
            nn.Dropout(p=0.5),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1000, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.conv1(x)  # 特征图(batch_size, 1, 28, 28) -> (batch_size, 32, 14, 14)
        x = self.conv2(x)  # 特征图(batch_size, 32, 14, 14) -> (batch_size, 64, 7, 7)
        x = x.view(x.size()[0], -1)  # ([batch_size, 64, 7, 7]) -> (batch_size, 64*7*7)
        x = self.fc1(x)  # (batch_size, 64*7*7) -> (batch_size, 1000)
        x = self.fc2(x)  # (batch_size, 1000) -> (1000, 10)
        return x

4. 训练+测试

        使用交叉熵代价函数(CrossEntropyLoss)和自适应矩阵优化算法(Adam)训练数据。代码如下:

LR = 0.001  # 学习率
model = Net()  # 模型
crossEntropy_loss = nn.CrossEntropyLoss()  # 交叉熵代价函数
optimizer = optim.Adam(model.parameters(), LR)


def train():
    model.train()
    for i, data in enumerate(train_loader):
        inputs, labels = data  # 获得一个批次的数据和标签
        out = model(inputs)  # 获得模型预测输出(64张图像,10个数字的概率)
        loss = crossEntropy_loss(out, labels)  # 使用交叉熵损失函数时,可以直接使用整型标签,无须独热编码
        optimizer.zero_grad()  # 梯度清0
        loss.backward()  # 计算梯度
        optimizer.step()  # 修改权值


def test():
    model.eval()
    correct = 0
    for i, data in enumerate(test_loader):
        inputs, labels = data  # 获得一个批次的数据和标签
        out = model(inputs)  # 获得模型预测结构(64,10)
        _, predicted = torch.max(out, 1)  # 获得最大值,以及最大值所在位置
        correct += (predicted == labels).sum()  # 判断64个值有多少是正确的
    print("测试集正确率:{}\n".format(correct.item() / len(test_loader)))


# 训练20个周期
for epoch in range(20):
    print("Epoch:{}".format(epoch))
    train()
    test()

         运行,等待片刻后,输出测试集的正确率为:

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pikachu-Unsafe Fileupload-服务端check

MIME MIME是Multipurpose Internet Mail Extensions (多用途互联网邮件扩展类型)的缩写,用来表示文件、文档、或字节流的性质和格式。是设定某种扩展名的文件用一种应用程序来打开的方式类型,当该扩展名文件被访问的时候&#xff…

论文阅读笔记-A Comparative Study on Transformer vs RNN in Speech Applications

前言 介绍 序列到序列模型已广泛用于端到端语音处理中,例如自动语音识别(ASR),语音翻译(ST)和文本到语音(TTS)。本文着重介绍把Transformer应用在语音领域上并与RNN进行对比。与传统的基于RNN的模型相比,将Transformer应用于语音的主要困难之一是,它需要更复杂的配…

【hot100-java】【分割回文串】

回溯篇 class Solution {//ret是需要返回的结果//path是回溯过程中的记录private final List<List<String>> retnew ArrayList<>();private final List<String> path new ArrayList<>();private String s;public List<List<String>>…

建筑资质的未来发展趋势

&#x1f3d7;️建筑资质是建筑企业进入市场的通行证&#xff0c;它不仅关系到企业的竞争力&#xff0c;也影响着整个建筑行业的健康发展。随着政策的调整和技术的进步&#xff0c;建筑资质管理正面临着新的变革。 1. 资质管理的数字化转型&#xff1a;&#x1f310; 随着信息技…

JavaScript-上篇

JS 入门 JS概述 JavaScript&#xff08;简称JS&#xff09;是一种高层次、解释型的编程语言&#xff0c;最初由布兰登艾奇&#xff08;Brendan Eich&#xff09;于1995年创建&#xff0c;并首次出现在网景浏览器中。JS的设计初衷是为Web页面提供动态交互功能&#xff…

区块链可投会议CCF C--CT-RSA 2025 截止10.15 附2024录用率

Conference&#xff1a;The Cryptographers Track at RSA Conference (CT-RSA) CCF level&#xff1a;CCF C Categories&#xff1a;network and information security Year&#xff1a;2025 Conference time&#xff1a;San Francisco, California, USA • April 28–May …

930/105每日一题

算法 1 4,2,9,11, 4, 2,4 2,4,9 42 4 24 9 2&#xff08;0&#xff09; 4&#xff08;1&#xff09; 9&#xff08;2&#xff09; 11&#xff08;3&#xff09; 11&#xff08;0&#xff09;11&#xff08;1&#xff09; 9&#xff08;2&#xff09; 11&#xff08;3&#xf…

深度学习:基于MindSpore实现CycleGAN壁画修复

关于CycleGAN的基础知识可参考&#xff1a; 深度学习&#xff1a;CycleGAN图像风格迁移转换-CSDN博客 以及MindSpore官方的教学视频&#xff1a; CycleGAN图像风格迁移转换_哔哩哔哩_bilibili 本案例将基于CycleGAN实现破损草图到线稿图的转换 数据集 本案例使用的数据集里…

Qt系统学习篇(6)-QMainWindow

QMainWindow是一个为用户提供主窗口程序的类&#xff0c;包含一个菜单栏(menu bar)、多个工具栏(tool bars)、多个锚接部件(dock widgets)、一个状态栏(status bar)及一个中心部件(central widget)&#xff0c;是许多应用程序的基础&#xff0c;如文本编辑器&#xff0c;图片编…

webpack信息泄露

先看看webpack中文网给出的解释 webpack 是一个模块打包器。它的主要目标是将 JavaScript 文件打包在一起,打包后的文件用于在浏览器中使用,但它也能够胜任转换、打包或包裹任何资源。 如果未正确配置&#xff0c;会生成一个.map文件&#xff0c;它包含了原始JavaScript代码的映…

算法笔记(九)——栈

文章目录 删除字符串中的所有相邻重复项比较含退格的字符串基本计算机II字符串解码验证栈序列 栈是一种先进后出的数据结构&#xff0c;其操作主要有 进栈、压栈&#xff08;Push&#xff09; 出栈&#xff08;Pop&#xff09; 常见的使用栈的算法题 中缀转后缀逆波兰表达式求…

关注、取关、Redis实现共同关注、 博客推送与分页查询

Resourceprivate StringRedisTemplate stringRedisTemplate;Resourceprivate IUserService userService;Overridepublic Result follow(Long followUserId, Boolean isFollow) {//1.获取登陆的用户Long userId UserHolder.getUser().getId();//1.判断是关注还是取关if(isFollo…

基于Springboot+Vue的小区运动中心预约管理系统的设计与实现 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 这段数…

FPGA-UART串口接收模块的理解

UART串口接收模块 背景 在之前就有写过关于串口模块的文章——《串口RS232的学习》。工作后很多项目都会用到串口模块&#xff0c;又来重新理解一下FPGA串口接收的代码思路。 关于串口相关的参数&#xff0c;以及在文章《串口RS232的学习》中已有详细的描述&#xff0c;这里就…

Linux启动mysql报错

甲方公司意外停电&#xff0c;所有服务器重启后&#xff0c;发现部署在Linux上的mysql数据库启动失败.再加上老员工离职&#xff0c;新接手项目&#xff0c;对Linux系统了解不多&#xff0c;解决起来用时较多&#xff0c;特此记录。 1.启动及报错 1.1 启动语句1 启动语句1&a…

Java编程基础(Scanner类==>循环语句)

文章目录 前言一、Scanner类1.创建Scanner对象2.使用3.实践 二、if条件语句1.简单if语句2.if-else语句3.if-else if-else语句3.实践 三、switch 开关语句四、循环语句1.for语句2.while语句3.do-while语句4.break和continue语句 总结 前言 我们发现在学习Java语言编程基础时&am…

【GEE数据库】WRF常用数据集总结

【GEE数据库】WRF常用数据集总结 GEE数据集介绍数据1:MODIS数据集LAI(叶面积指数)和Fpar(绿色植被率)年尺度土地利用类型数据2:月反射率(Monthly Albedo)数据3:LULC和ISA参考GEE数据集介绍 GEE数据搜索网址-A planetary-scale platform for Earth science data &…

(PyTorch) 深度学习框架-介绍篇

前言 在当今科技飞速发展的时代&#xff0c;人工智能尤其是深度学习领域正以惊人的速度改变着我们的世界。从图像识别、语音处理到自然语言处理&#xff0c;深度学习技术在各个领域都取得了显著的成就&#xff0c;为解决复杂的现实问题提供了强大的工具和方法。 PyTorch 是一个…

9.30学习记录(补)

手撕线程池: 1.进程:进程就是运行中的程序 2.线程的最大数量取决于CPU的核数 3.创建线程 thread t1; 在使用多线程时&#xff0c;由于线程是由上至下走的&#xff0c;所以主程序要等待线程全部执行完才能结束否则就会发生报错。通过thread.join()来实现 但是如果在一个比…

08.STL简介

1. 什么是STL STL(standard template libaray-标准模板库)&#xff1a;是C标准库的重要组成部分&#xff0c;不仅是一个可复用的 组件库&#xff0c;而且是一个包罗数据结构与算法的软件框架 2.发展历史 1. 起源与早期探索&#xff08;20世纪80年代初期&#xff09;&#xff…