机器学习:卷积神经网络

news2025/1/11 20:07:29

卷积神经网络

  • 卷积神经网络的结构及原理
    • 卷积层
    • 池化层
    • 激活函数
    • 全连接层
    • 反馈运算
  • 使用MNIST数据集进行代码解析
    • 数据介绍
    • 实现流程
    • 代码实现

卷积神经网络的结构及原理

卷积层

卷积运算一个重要的特点就是:通过卷积运算,可以使原信号特征增强,并且降低噪音。
在这里插入图片描述
以二维为例,卷积核在二维平面上平移,对应位置相乘,得到一个新图像,即对图像的每个像素的邻域(邻域大小就是核的大小)加权求和得到该像素点的输出值。
在这里插入图片描述

池化层

通常使用的池化操作为平均值池化 (average-pooling)和最大值池化(max-pooling)。

池化层不包含需要学习的参数使用时仅需指定池化类型(average 或max)、池化操作的核大小(kernel size)和池化操作的步长等超参数即可。
在这里插入图片描述
!!!注意区分卷积和池化步幅的移动,都是需要定义的

池化层的引入是仿照人的视觉系统对视觉输入对象进行降维(下采样)和抽象。池化层有三种功效:

1.特征不变性。池化操作使模型更关注是否存在模型特征而不是特征具体的位置。
2.特征降维。
3.在一定程度上防止过拟合,更方便优化。

激活函数

又称非线性映射层。激活函数的引入为的是增加整个网络的表达能力(即非线性),否则,若干线性操作层的堆叠仍然只能起到线性映射的作用,无法形成复杂的函数。

激活函数应该具有的性质如下:

▨非线性。
▨连续可微。
▨范围最好不饱和,当有饱和的区间段时,若系统优化进入到该段,梯度近似为 0,网络的学习就会停止。
▨单调性。当激活函数是单调时,单层神经网络的误差函数是凸的,好优化。
▨在原点处近似线性,这样当权值初始化为接近0的随机值时,网络可以学习得较快,不用调节网络的初始值。
通常使用的激活函数有sigmoid、tanh和relu函数。

sigmoid函数

在这里插入图片描述
tanh 函数
relu 函数

在这里插入图片描述
relu 函数在梯度下降中能够快速收敛。relu函数有效缓解了梯度消失的问题,但是随着训练的继续,可能会出现神经元死亡,权重无法更新的情况。也即是说,relu 函数下的神经元在训练中不可逆地死亡了

全连接层

起到“分类器”的作用。如果说卷积层、池化层和激活函数层等操作是将原始数据映射到隐藏层特征空间的话,全连接层则起到将学到的“分布式特征表示”映射到样本标记空间的作用。

在基本的 CNN 网络中,全连接层的作用是将经过多个卷积层和池化层的图像特征图中的特征进行整合,获取图像特征具有的高层含义,之后用于图像分类。在 CNN 网络中,全连接层将卷积层产生的特征图映射成一个固定长度(一般为输入图像数据集中的图像类别数)的特征向量

反馈运算

调参的过程。

在卷积神经网络求解时,特别是针对大规模应用问题,常采用批处理的随机梯度下降法。批处理的随机梯度下降法在训练模型阶段随机选取 n 个样本作为一批样本,先通过前馈运算得到预测并计算其误差,后通过梯度下降法更新参数,梯度从后往前逐层反馈,直至更新到网络的第一层参数,这样的一个参数更新过程称为一个“批处理过程”。

批处理样本的大小不宜设置过小。过小时,由于样本采样随机,按照该样本上的误差更新模型参数不一定在全局上最优(此时仅为局部最优更新),会使得训练过程产生振荡。批处理大小的上限则主要取决于硬件资源的限制,如 GPU 显存大小。

使用MNIST数据集进行代码解析

MNIST数据集算是机器学习入门的数据,常用来做分类处理。总的标签类别就是0到9这十个数字,图片之间不同的是采用的是一些人群的手写体,目的就是根据不同手写体识别图片进而判断数字是几。

数据介绍

在这里插入图片描述
在MNIST数据集官网http://yann.lecun.com/exdb/mnist/ 下载后的压缩包不用解压缩,可直接放在一个文件夹下,直接调用。

数据包含55000个训练数据集,5000个验证数据集,10000个测试数据集(我还不知道验证数据集有什么用),每张图片有28*28个像素点,即 784 个像素点,我们可以把它展开形成一个向量,即长度为 784 的向量。

实现流程

先定义好权重偏置函数,卷积函数tf.nn.conv2d,池化函数(最大池化max_pool_2x2)

卷积 -> 激活函数 -> 池化 -> 全连接 -> 激活函数 -> Dropout层 -> AdamOptimizer优化器(反馈) -> 训练模型 -> 评估

卷积:突出局部的特征
激活函数:非线性处理,增加表达能力,输入的是(卷积结果 + 偏置)
池化:降维
全连接:一步步收敛,映射到标签y
Dropout层:防止过拟合,扔到部分神经元
AdamOptimizer优化器:模型复杂计算量大,直接使用此优化器进行反馈定义合适的权重和偏置值
训练模型:上面搭建好的框架,给入数据进行模型训练
评估:引入测试集数据,评估模型

代码实现

# -*- coding: utf-8 -*-
"""
Created on Thu Nov 17 21:30:31 2022

@author: Yangz
"""

'''
卷积 -> 激活函数 -> 池化 -> 全连接 -> 激活函数 -> Dropout层 -> AdamOptimizer优化器(反馈) -> 训练模型 -> 评估

卷积:突出局部的特征
激活函数:非线性处理,增加表达能力,输入的是(卷积结果 + 偏置)
池化:降维
全连接:一步步收敛,映射到标签y
Dropout层:防止过拟合,扔到部分神经元
AdamOptimizer优化器:模型复杂计算量大,直接使用此优化器进行反馈定义合适的权重和偏置值
训练模型:上面搭建好的框架,给入数据进行模型训练
评估:引入测试集数据,评估模型

'''

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow.compat.v1 as tf
#tensorflow版本原因2.0往上的用这种方式
tf.compat.v1.disable_eager_execution()

'''
读取MNIST数据,MNIST数据集中每一张图片大小都是28*28的,因此有784个像素点
所以输入值x应该是一个长度为784的向量,y_是标签长度为10

'''
mnist = input_data.read_data_sets("./MNIST_data",one_hot=True)
sess = tf.InteractiveSession() #与BP不同的另一种sess定义

train_nums = mnist.train.num_examples #训练集样本数据大小55000
validation_nums = mnist.validation.num_examples # 验证集样本数据大小5000
test_nums = mnist.test.num_examples  #测试集样本数据大小10000

# 设定两个参数,输入x、实际输出y_的占位符;x是数据,y_是标签
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

keep_prob = tf.placeholder(tf.float32) #改变参与计算的神经元个数的值,全连接层的输出
#x_image = tf.reshape(x, [-1,28,28,1])  #对输入数据x进行变形

'''
权重、偏置函数
'''
def weight_variable(shape):
    # 产生随机变量,
    #truncated_normal()函数:选取位于正态分布均值=0.1附近的随机值
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape) #常量
    return tf.Variable(initial)

'''
卷积函数、池化函数
'''
#输入图像与卷积核,W=[filter_height, filter_width, in_channels, out_channels]
#具体含义是[卷积核的高度,卷积核的宽度,图像通道数,输出通道数]
def conv2d(x, W): 
    #stride = [1,in_height移动步长,in_width移动步长,1]
    #卷积层conv2d()函数里strides参数要求第一个、最后一个参数必须是1,即只能在一个样本的一个通道上的特征图上进行移动
    #padding:string类型的量,只能是"SAME","VALID"其中之一,
    #当其为‘SAME’时,表示卷积核可以停留在图像边缘,保留边界信息和图像大小
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):#选用最大池化方法,池化步长在这里固定为2
    #k_size : 池化窗口的大小,取一个四维向量,一般是[1, height, width, 1]
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

#tf.reshape:哪一维使用了-1,那这一维度就不定义大小,而是根据你的数据情况进行匹配
x_image = tf.reshape(x, [-1,28,28,1])  #对输入数据x进行变形

'''
第一层卷积 + 池化
卷积核大小5*5,输入通道数1,输出通道数(深度)32,
'''
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
#激活函数使用reLU进行非线性处理,大小28*28*32
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) #卷积后的结果加上偏置值
#第一次池化 输出结果大小14*14*32
h_pool1 = max_pool_2x2(h_conv1)

'''
第二层卷积 + 池化
卷积核大小5*5,第输入的通道数是32,输出的通道数是64
'''
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
#第二次卷积,输出结果大小14*14*64
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
# 第二次池化,输出结果大小 7 *7 *64
h_pool2 = max_pool_2x2(h_conv2)

'''
全连接层,这里构造了两个全连接层
'''
#第一个全连接层,转化为一维,定义输出节点数1024

W_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])  #三维变形为二维
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) #tf.matmul矩阵相乘全连接
#防止过拟合,使用Dropout层,是在不同的训练过程中随机扔掉一部分神经元,
#仅仅是不参与计算,但权重保留,之后样本输入时依旧工作,依旧进行权重更新
#train的时候才是dropout起作用的时候,test的时候不应该让dropout起作用
#tf.nn.dropout: keep_prob,每个元素被保留下来的概率
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 

#第二个全连接层,收敛对应10个标签,最后一个全连接层也可记为分类的过程
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
prediction = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

'''
反馈调参、求准确率
'''
#计算平均误差
#tf.nn.softmax_cross_entropy_with_logits:交叉熵损失函数,第一个为实际值,第二个为预测值
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=prediction))
#对于神经网络部分的反馈运算,即权重、偏置更新过程,由于计算量庞大,
#一般都直接使用AdamOptimizer优化器,BP神经网络学习中已有介绍
#minimize(loss)根据其损失量学习自适应,损失量大则学习率大
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
#结果存放在一个布尔型列表中,便于后续计算准确率
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y_,1))
#求准确率
#tf.cast:数据类型转换,这里是将布尔类型值correct_prediction转换为浮点型
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

'''
模型训练
'''
saver = tf.train.Saver()  # defaults to saving all variables
init_op = tf.global_variables_initializer()
sess.run(init_op)


#tf.global_variables_initializer().run()
for i in range(1000):
    #每次训练固定训练集大小,10
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})#只有一半的神经元参与计算
    #每100次计算一次准确率
    if i%100 == 0:
        #accuracy.eval 相当于 sess.run(accuracy,…)
        train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
        print("step",i, "training accuracy",train_accuracy)
    
    #batch[0] = mnist.train.images
    #batch[1] = mnist.train.labels
    
#保存模型参数
#saver.save(sess, 'F://小组//Python//机器学习//卷积神经网络//model.ckpt')
#使用测试机数据估计模型准确率
print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) #所有的神经元都参与评估计算

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/19826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务器常用的异常及性能排查

服务器常用的异常及性能排查 使用 top 命令查看性能指标 top 命令使用详细介绍:传送门 查看Tasks total 进程数 正常我们在使用过程中对每天的一个进程数大概是有一个谱的,比如正常就是1百多个,突然暴增几百,那就很明显这里有…

计算机网络:运输层

运输层 运输层主要解决了应用进程之间的通信,称之为端到端协议 1.运输层概述 计算机网-------络体系结构的角度 AP:应用进程之间的简称 2. 运输层端口号、复用与分用的概念 2.1 端口号 2.2 发送方的复用(multiplexing)和接收方的分用&…

【Java面试八股文宝典之基础篇】备战2023 查缺补漏 你越早准备 越早成功!!!——Day10

大家好,我是陶然同学,软件工程大三明年实习。认识我的朋友们知道,我是科班出身,学的还行,但是对面试掌握不够,所以我将用这100多天更新Java面试题🙃🙃。 不敢苟同,相信大…

锐捷RLDP理论及实验讲解

RLDP概念 RLDP(Rapid Link Detection Protocol)是一个用于快速检测以太网链路故障的链路协议,包括环路链路故障、单向链路故障、双向链路故障等 工作原理 RLDP定义了两种协议报文:探测报文(Probe)和探测响…

【Java第32期】:Spring 中普通Maven项目的创建

作者:有只小猪飞走啦 博客地址:https://blog.csdn.net/m0_62262008?typeblog 内容:Spring 中普通Maven项目的创建 文章目录前言一,创建Spring项目1.创建一个普通的Maven项目2,添加Spring框架3,添加启动类…

【数据库系统概论】关系数据理论、范式

数据库一二三范式简单解释 第一范式 一个关系模式应当是一个五元组。 R(U,D,DOM,F)R(U,D,DOM,F)R(U,D,DOM,F) 这里: 关系名RRR是符号化的元组语义UUU为一组属性DDD为属性组UUU中的属性所来自的域DOMDOMDOM为属性到域的映射FFF为属性组UUU上的一组数据依赖 由于D…

RabbitMQ_概述

RabbitMQ大致工作流程图 解释 Producer:生产者 Consumer:消费者 Connection:AMQP协议连接 Channel:信道,进行消息读写的通道,RabbitMQ的绝大部分操作在信道完成;客户端可以建立多个信道&…

用 AWTK 和 AWPLC 快速开发嵌入式应用程序 (4)- 自定义功能块(上)

AWPLC 目前还处于开发阶段的早期,写这个系列文章的目的,除了用来验证目前所做的工作外,还希望得到大家的指点和反馈。如果您有任何疑问和建议,请在评论区留言。 1. 背景 AWTK 全称 Toolkit AnyWhere,是 ZLG 开发的开源…

PTA题目 两个数的简单计算器

本题要求编写一个简单计算器程序,可根据输入的运算符,对2个整数进行加、减、乘、除或求余运算。题目保证输入和输出均不超过整型范围。 输入格式: 输入在一行中依次输入操作数1、运算符、操作数2,其间以1个空格分隔。操作数的数…

跟艾文学编程《Python基础》(7)pandas数据分析

作者: 艾文,计算机硕士学位,企业内训讲师和金牌面试官,公司资深算法专家,现就职BAT一线大厂。邮箱: 1121025745qq.com博客:https://wenjie.blog.csdn.net/内容:跟艾文学编程《Python…

汉字风格迁移篇---W-net:基于深度神经网络的一次任意风格汉字生成

文章目录一、摘要二、提出原因已有的一些模型解决方案依然存在的限制三、介绍与创新四、模型介绍预处理w-net结构优化策略和损失函数五、实验实验设置用zi2zi作为基线具体实现1、 W-Net训练期间的超参数设置如下:2、一些细节处理模型评估W-net、zi2zi-v1、zi2zi-v2不…

第2-3-7章 个人网盘服务接口开发-文件存储服务系统-nginx/fastDFS/minio/阿里云oss/七牛云oss

文章目录5.8 导入其他接口代码5.8.1 接口导入-分页查询附件5.8.2 接口导入-根据业务类型/业务id查询附件5.9 导入网盘服务接口5.9.1 导入FileController5.9.2 导入StatisticsController5.9.3 导入FileRestManager5.9.4 导入FileService5.9.5 导入FileServiceImpl5.9.6 扩展File…

面向OLAP的列式存储DBMS-8-[ClickHouse]的常用聚合函数

ClickHouse 中的常用聚合函数 1 聚合函数 ClickHouse 中的聚合函数,因为和关系型数据库的相似性,本来聚合函数不打算说的,但是 ClickHouse 提供了很多关系型数据库中没有的函数,所以我们还是从头了解一下。 1.1 count count&…

Vue3 用src动态引入本地图片

💭💭 ✨: Vue3 用src动态引入本地图片   💟:东非不开森的主页   💜: 躲起来的星星也在努力发光 你也要💜💜   🌸: 如有错误或不足之处,希望可以指正&#…

Qt OpenGL(二十二)——Qt OpenGL 核心模式-VAO和VBO

Qt OpenGL(二十二)——Qt OpenGL 核心模式-VAO和VBO 一、再谈VAO、VBO 上一篇文章,通过VAO、VBO绘制了一个三角形,过程需要创建VAO、VBO和释放。之所以有这些步骤,就是因为OpenGL本质就是一个大的状态机。但是我们如果要继续学习核心模式的OpenGL的话,VAO、VBO是我们必…

Java集合(一):泛型与Collection集合

目录 集合预热:泛型 泛型的优点 自定义泛型类型 自定义泛型类/接口 泛型使用细节 自定义泛型方法 泛型与继承关系 不存在继承关系的情况 通配符与存在继承关系的情况 泛型受限 集合概述 集合的作用与存储内容 集合与数据结构 集合:Collectio…

【基础算法系列】离散化与前缀和算法的运用

⭐️前面的话⭐️ 本篇文章将主要介绍离散化算法,所谓离散化算法,就是将一个无限区间上散点的数,在不改变相对大小的情况下,映射到一个较小的区间当中,然后对这个较小的区间进行操作的过程就是离散化的过程&#xff0…

【C++笔试强训】第二十八天

🎇C笔试强训 博客主页:一起去看日落吗分享博主的C刷题日常,大家一起学习博主的能力有限,出现错误希望大家不吝赐教分享给大家一句我很喜欢的话:夜色难免微凉,前方必有曙光 🌞。 💦&a…

微信小程序自定义tabBar(实操)

文章目录一、前言二、固定效果图实现步骤实现步骤完整代码-矢量图images图片app.json代码三、自定义效果图实现步骤实现步骤完整代码-矢量图images图片app.json代码custom-tab-bar下的代码使用自定义TaBar一、前言 一般使用tabBar的样式,固定不能改变。如下固定效果…

java计算机毕业设计springboot+vue村委会管理系统

项目介绍 本村委会管理系统是针对目前村委会管理的实际需求,从实际工作出发,对过去的村委会管理系统存在的问题进行分析,完善用户的使用体会。采用计算机系统来管理信息,取代人工管理模式,查询便利,信息准确率高,节省了开支,提高了工作的效率。 本系统结合计算机系统的结构、概…