深度学习——CNN卷积神经网络

news2024/12/23 20:45:49

基本概念

概述

卷积神经网络(Convolutional Neural Network,CNN)是一种深度学习中常用于处理具有网格结构数据的神经网络模型。它在计算机视觉领域广泛应用于图像分类、目标检测、图像生成等任务。

核心思想

CNN 的核心思想是通过利用局部感知和参数共享来捕捉输入数据的空间结构信息。相比于传统的全连接神经网络,CNN 在网络结构中引入了卷积层和池化层,从而减少了参数量,并且能够更好地处理高维输入数据。

其他概念

输入层:接收原始图像或其他形式的输入数据。
卷积层(Convolutional Layer):使用卷积操作提取输入特征,通过设置滤波器(卷积核)在输入数据上滑动并执行卷积运算。这样可以学习到局部的特征,如边缘、纹理等。
激活函数(Activation Function):在每个卷积层后面通常紧跟一个非线性的激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。
池化层(Pooling Layer):通过减少特征图的尺寸来降低模型复杂性。常用的池化操作是最大池化(Max Pooling),它选取每个池化窗口内的最大特征值作为输出。
全连接层(Fully Connected Layer):将卷积层和池化层的输出连接到全连接层,使用传统的神经网络模式进行分类、回归等任务。
Dropout 层:在训练过程中以一定概率随机将部分神经元的输出置为0,以减少模型的过拟合。
Softmax 层:多分类问题中常用的输出层,在最后一层进行 softmax 操作将输出转化为类别上的概率分布。

代码与详细注释

import os

# third-party library
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
#  轮次
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
# 批大小为50
BATCH_SIZE = 50
# 学习率
LR = 0.001
# 是否下载mnist数据集
DOWNLOAD_MNIST = False


# 下载minist数据集
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
    # not mnist dir or mnist is empyt dir
    DOWNLOAD_MNIST = True

# torchvision本身就是一个数据库
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,
)

# 输出训练数据尺寸
print(train_data.train_data.size())                 # (60000, 28, 28)
# 输出标签数据尺寸
print(train_data.train_labels.size())               # (60000)
# 展示训练数据集中的第0个图片
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
# 图片的标题是标签
plt.title('%i' % train_data.train_labels[0])
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
# 批大小为50,shuffle为True意思是设置为随机
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# pick 2000 samples to speed up testing
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# 使用unsqueeze增加一个维度
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.   # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 快速搭建神经网络
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input height
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    # 前向传播
    def forward(self, x):
        # 第一层卷积
        x = self.conv1(x)
        # 第二层卷积
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x    # return x for visualization


cnn = CNN()
print(cnn)  # net architecture

# 选择优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted

# following function (plot_with_labels) is for visualization, can be ignored if not interested
from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
    plt.cla()
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
    plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()


# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loader

        output = cnn(b_x)[0]            # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            test_output, last_layer = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
            if HAS_SK:
                # Visualization of trained flatten layer (T-SNE)
                tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
                plot_only = 500
                low_dim_embs = tsne.fit_transform(last_layer.data.numpy()[:plot_only, :])
                labels = test_y.numpy()[:plot_only]
                plot_with_labels(low_dim_embs, labels)
plt.ioff()

# print 10 predictions from test data
test_output, _ = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')



运行结果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/746004.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

环形链表(快慢指针)

给你一个链表的头节点 head ,判断链表中是否有环。如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置(索…

C++ | 继承

目录 前言 一、继承的基本概念与使用 1、继承的概念 2、继承的定义 3、继承的访问限定符与继承方式 二、基类与派生类之间的赋值转换(切片) 三、继承中的作用域 1、继承中的作用域 2、隐藏(重定义) 四、派生类的默认构…

知识付费小程序怎么做

知识付费小程序是一种通过在线平台提供知识和教育内容的应用程序。下面将详细介绍其功能: 1. 音频视频课程: 知识付费小程序提供了丰富的音频和视频课程,在这些课程中,用户可以通过观看或听取专业讲师的讲解来学习各种知识领域。…

【文章系列解读】Nerf

1. Nerf NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis 2020年8月3日 (0)总结 NeRF工作的过程可以分成两部分:三维重建和渲染。(1)三维重建部分本质上是一个2D到3D的建模过程&#xff…

Java习题之实现平方根(sqrt)函数

目录 前言 二分查找 牛顿迭代法 总结 🎁博主介绍:博客名为tq02,已学C语言、JavaSE,目前学了MySQL和JavaWed 🎥学习专栏: C语言 JavaSE MySQL基础 🎄博主链接:tq02的…

【第四章 flutter-初识flutter】

文章目录 一、目录结构二、创建一个flutter项目三、创建自定义组件四、Container组件 就是divalignment 内容对齐方式decoration 类似border 为BoxDecoration的类 五、Text属性六、image组件总结、 一、目录结构 android、ios各自平台的资源文件 lib 项目目录 linux macos PC平…

Linux宝塔Mysql读写分离配置,两台服务器,服务器存在多个库

Linux宝塔Mysql读写分离配置,两台服务器,服务器存在多个库 一、主库操作 #登录数据库,用root登录方便,用其他账号会提示权限不足,需要登录root给予权限 mysql -u root -p 密码#创建一个账号,供从库用该账…

电商企业需要部署WMS仓储管理系统吗

随着电子商务行业的迅速发展,电商企业面临着日益增长的订单量和复杂的物流流程。为了提高仓储管理的效率和准确性,许多电商企业开始考虑部署WMS仓储管理系统。然而,是否真的需要部署WMS仓储管理系统,仍然是一个值得探讨的问题。本…

vLLM大模型推理加速方案原理(PagedAttention)

一、vLLM 简介 vLLM 用于大模型并行推理加速,核心是 PagedAttention 算法,官网为:https://vllm.ai/。 vLLM 主要特性: 先进的服务吞吐量通过 PagedAttention 对注意力 key 和 value 进行内存管理对传入请求的批处理针对 CUDA 内…

纯代码和低代码的本质区别

一、前言 纯代码和低代码是现代软件开发中两种不同的方法。 纯代码需要专业的编程技能,掌握编程语言、算法和数据结构等专业知识。而低代码则是一种新兴的开发方法,它大大降低了对编程技能的要求,让非技术人员也能够创建应用程序。随着低代码…

【SpringBoot】--03.数据访问、基础特性(外部化和内部外配置、整合JUnit)

文章目录 SpringBoot3-数据访问1.整合SSM场景1.1创建SSM整合项目1.2配置数据源1.3配置MyBatis1.4CRUD编写 2.自动配置原理3.扩展:整合其他数据源3.1 Druid 数据源 SpringBoot3-基础特性1. SpringApplication1.1 自定义 banner1.2.自定义 SpringApplication1.3Fluent…

nvm 管理node 环境配置

nvm安装: nvm(Node Version Manager)是一个用来管理node版本的工具。我们之所以需要使用node,是因为我们需要使用node中的npm(Node Package Manager),使用npm的目的是为了能够方便的管理一些前端开发的包!…

ColorOS凭什么夺冠?

摘要:五大主流安卓系统流畅度PK,谁的体验最好? 评价一款手机,你最先看的是什么? 是处理器平台?CPU核心频率?还是内存配置? 虽然这些硬件参数能够清晰地反映几款不同配置手机之间的性…

20230712-----阻塞IO驱动按键控制LED灯的亮灭

驱动程序 #include <linux/init.h> #include <linux/module.h> #include <linux/fs.h> #include <linux/device.h> #include <linux/cdev.h> #include <linux/slab.h> #include <linux/uaccess.h> #include <linux/of.h> #in…

TeeChart for.NET Crack

TeeChart for.NET Crack TeeChart for.NET为各种图表需求提供了图表控件&#xff0c;包括金融、科学和统计等重要的垂直领域。它可以处理您的数据&#xff0c;在各种平台上无缝创建信息丰富、引人入胜的图表&#xff0c;包括Windows窗体、WPF、带有HTML5/Javascript渲染的ASP.N…

敢不敢和AI比猜拳?能赢算我输----基于手势识别的AI猜拳游戏【含python源码+PyqtUI界面+原理详解】-python手势识别 深度学习实战项目

功能演示 摘要&#xff1a;手势识别是一种通过技术手段识别视频图像中人物手势的技术。本文详细介绍了手势识别实现的技术原理&#xff0c;同时基于python与pyqt开发了一款带UI界面的基于手势识别的猜拳游戏。手势识别采用了mediapipe的深度学习算法进行手掌检测与手部的关键点…

字符设备驱动开发(最初方式)

目录&#xff1a; 1.字符设备驱动简介2.字符设备驱动开发步骤2.1. 驱动模块的加载与卸载2.2. Makefile的编写2.3.字符设备的注册与注销2.3.1.设备号的组成2.3.2.设备号的分配 2.4.具体操作函数的实现2.4.1.进行打开和关闭操作2.4.2.对chrdev进行读写操作 3.具体程序的实现3.1.驱…

第十一章——使用类

运算符重载 运算符重载是一种形式的C多态。之前介绍过的函数重载&#xff08;定义多个名称相同但特征标不同的函数&#xff09;让程序员能够用同名的函数来完成相同的基本操作&#xff0c;即使这些操作被用于不同的数据类型。 运算符重载将重载的概念扩展到运算符上&#xff0…

gulimall-性能监控与压力测试

性能监控与压力测试 前言一、性能监控1.1 jvm 内存模型1.2 jvisualvm 作用1.3 监控指标 二、压力测试2.1 概念2.2 性能指标2.3 JMeter 压测工具 前言 本文继续记录B站谷粒商城项目视频 P141-150 的内容&#xff0c;做到知识点的梳理和总结的作用。 一、性能监控 1.1 jvm 内存…

灯具小程序怎么制作

灯具小程序怎么制作&#xff0c;有什么功能 1. 商品展示&#xff1a;灯具小程序商城提供了丰富多样的灯具产品&#xff0c;并通过清晰的商品展示页面展示给用户。用户可以浏览不同种类的灯具&#xff0c;包括吊灯、台灯、壁灯等&#xff0c;了解产品的图片、规格、价格等详细信…