神经网络--从0开始搭建全连接网络和CNN网络

news2024/9/25 19:25:27

在这里插入图片描述

前言: Hello大家好,我是Dream。 今天来学习一下如何从0开始搭建全连接网络和CNN网络,并通过实验简单对比一下两种神经网络的不同之处,本文目录较长,可以根据需要自动选取要看的内容~

本文目录:

  • 一、搭建4层全连接神经网络
    • 1.调用库函数
    • 2.选择模型,构建网络
    • 3.编译(使用交叉熵作为loss函数)
    • 4.输出
    • 5.画出图像
    • 6.结论
  • 二、搭建CNN网络
    • 1.调用库函数
    • 2.调用数据集
    • 3.图片归一化
    • 4.选择模型,构建网络
    • 5.编译
    • 6.批量输入的样本个数
    • 7.训练
    • 8.输出
    • 9.画出图像
    • 10.结论
  • 三、两种网络对比
  • 四、源码获取

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下四个任务:

  • 加载keras内置的mnist数据库

  • 自己搭建简单神经网络,并自选损失函数和优化方法

  • 搭建4层全连接神经网络,除输入层以外,各层神经元个数分别为1000,300,64,10,激活函数自选

  • 搭建CNN网络,要求有1个卷积层(32卷积核),1个池化层(2x2),1个卷积层(16卷积核),1个全局池化层(globalMaxPool),一个全连接输出层,激活函数自选

一、搭建4层全连接神经网络

加载keras内置的mnist数据库,搭建4层全连接神经网络,除输入层以外,各层神经元个数分别为1000,300,64,10,激活函数自选

1.调用库函数

import tensorflow as tf
import matplotlib.pyplot as plt
mnist = tf.keras.datasets.mnist
from tensorflow.keras.layers import Flatten,Dense,Dropout

2.选择模型,构建网络

搭建4层全连接神经网络,除输入层以外,各层神经元个数分别为1000,300,64,10

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 选择模型,构建网络
model = tf.keras.models.Sequential()
model.add(Flatten(input_shape=(28, 28)))
# 各层神经元个数分别为1000,300,64,10
model.add(Dense(1000, activation='relu'))
model.add(Dense(300, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))  # 采用20%的dropout
model.add(Dense(10, activation='softmax'))  # 输出结果是10个类别,所以维度是10,最后一层用softmax作为激活函数

3.编译(使用交叉熵作为loss函数)

指明优化器、损失函数、准确率计算函数

# 编译(使用交叉熵作为loss函数),指明优化器、损失函数、准确率计算函数
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 训练(训练10个epoch)
history = model.fit(x_train, y_train, epochs=10)

这里是训练的结果:
在这里插入图片描述

4.输出

输出测试集上的预测准确率

# 输出
scores = model.evaluate(x_test,y_test)
print(scores)
print("The accuracy of the model is %f" % scores[1])  #输出测试集上的预测准确率

这里是输出的结果:
在这里插入图片描述

5.画出图像

使用plt模块进行数据可视化处理

# 画出图像
plt.plot(history.history['loss'], color='red', label='Loss')
plt.legend(loc='best')
plt.title('Training Loss')
plt.show()

在这里插入图片描述

6.结论

第一种神经网络准确率:0.976200

二、搭建CNN网络

要求有1个卷积层,1个池化层,1个全局池化层(globalMaxPool),一个全连接输出层,激活函数自选

1.调用库函数

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import math
from tensorflow.keras.layers import Conv2D,MaxPooling2D,GlobalMaxPooling2D,Flatten,Dense

2.调用数据集

加载keras内置的mnist数据库

# 调用数据集
dataset, metaset = tfds.load('mnist', as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

3.图片归一化

# 图片归一化
def normalize(images, labels):
    images = tf.cast(images, tf.float32)
    images /= 255
    return images, labels
train_dataset = train_dataset.map(normalize)
test_dataset = test_dataset.map(normalize)

4.选择模型,构建网络

构建出1个卷积层,1个池化层,1个全局池化层(globalMaxPool),一个全连接输出层

# 选择模型,构建网络
model = tf.keras.Sequential()

# 卷积层
model.add(Conv2D(32, (5, 5), padding='same', activation=tf.nn.relu, input_shape=(28, 28, 1))),  

# 池化层 
model.add(MaxPooling2D((2, 2), strides=2)), 

# 全局池化层(globalMaxPool)
model.add(Conv2D(64, (5, 5), padding='same', activation=tf.nn.relu)),  # 卷积层
model.add(GlobalMaxPooling2D()),

 # 全连接输出层
model.add(Flatten()),#展平
model.add(Dense(512, activation=tf.nn.relu)),
model.add(Dense(10, activation=tf.nn.softmax))# 输出结果是10个类别,所以维度是10,最后一层用softmax作为激活函数

5.编译

指明优化器、损失函数、准确率计算函数

# 编译(使用交叉熵作为loss函数),指明优化器、损失函数、准确率计算函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 展示训练的过程
display(model.summary())

这里是输出的结果:
在这里插入图片描述

6.批量输入的样本个数

# 批量输入的样本个数
BATCH_SIZE = 64
num_train = metaset.splits['train'].num_examples
num_test = metaset.splits['test'].num_examples
train_dataset = train_dataset.repeat().shuffle(num_train).batch(BATCH_SIZE)
test_dataset = test_dataset.repeat().shuffle(num_test).batch(BATCH_SIZE)

7.训练

训练10个epoch

# 训练(训练10个epoch)
history = model.fit(train_dataset, epochs=10, steps_per_epoch=math.ceil(num_train / BATCH_SIZE))

这里是输出的结果:
在这里插入图片描述

8.输出

# 输出
test_loss, test_accuracy = model.evaluate(test_dataset, steps=math.ceil(num_test / BATCH_SIZE))
print(test_loss, test_accuracy)

这里是输出的结果:
在这里插入图片描述

9.画出图像

使用plt模块进行数据可视化处理

# 画出图像
plt.plot(history.history['loss'], color='red', label='Loss')
plt.legend(loc='best')
plt.title('Training Loss')
plt.show()

这里是输出的结果:
在这里插入图片描述

10.结论

第二种神经网络准确率:0.993232

三、两种网络对比

第一种神经网络准确率:0.976200 第二种神经网络准确率:0.993232
总结: 通过对比我们可以发现CNN卷积神经网络相对于传统神经网络NN准确率会高一些,由卷积的操作可知,输出图像中的任何一个单元,只跟输入图像的一部分有关系。而传统神经网络中,由于都是全连接,所以输出的任何一个单元,都要受输入的所有的单元的影响。这样无形中会对图像的识别效果大打折扣,因此CNN在此种方面会更具优势

四、源码获取

关注此公众号:人生苦短我用Pythons,回复 神经网络实验获取源码,快点击我吧

🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~
在这里插入图片描述
在这里插入图片描述

最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/187298.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark核心RDD详解(设计与运行原理,分区,创建,转换,行动与持久化)

RDD设计背景与概念 在实际应用中,存在许多迭代式算法(比如机器学习、图算法等)和交互式数据挖掘工具,这些应用场景的共同之处是,不同计算阶段之间会重用中间结果,即一个阶段的输出结果会作为下一个阶段的输…

go: GOPATH entry is relative; must be absolute path: “F:oocode“.

系列文章目录 文章目录系列文章目录前言一、可以先查看一下啊二、gopath和goroot变量要和设置的一致总结前言 在安装hertz 之类的 总会弹出go 的不合法 等 出现这样的错误 要不就是go的不合法 会爆红 说go无这种命令 go:术语“ go”未被识别为cmdlet,函…

Hystrix断路器

目录 一、概述 (一)分布式系统面临的问题 (二)Hystrix是什么 (三)能干吗 (四)官网 (五)Hystrix官宣,停更进维 二、Hystrix重要概念 &…

JAVA开发(springBoot之HikariDataSource)

HikariDataSource是springBoot自带的数据源管理工具。应该是有zaxxer公司提供贡献给spring社区的。它是一款优秀的数据库连接池工具(新的东西一般会吹吹牛),号称 Java WEB 当前速度最快的数据源,相比于传统的 C3P0 、DBCP、Tomcat…

【数据结构】认识顺序表

目录 1、先来认识一下线性表 1.1、对非空的线性表或者线性结构的特点: 1.2、线性表的实现方式 2、顺序表 2.1、定义一个类,实现顺序表 2.2、顺序表的操作方法 2.2.1、打印顺序表(display) 2.2.2、获取顺序表的长度&#x…

Rancher 中使用 Longhorn 备份恢复数据卷

全文目录导航0. 前言1. NFS 安装配置1.1 安装 nfs 及 rpcbind1.2 创建共享目录1.3 配置访问权限1.4 限制 showmount -e 防止漏洞扫描1.5 防火墙配置2. Longhorn 备份配置2.1 在 Longhorn UI 中配置3. 数据卷备份恢复操作3.1 创建示例工作负载3.2 创建测试数据3.3 创建数据卷备份…

车载以太网 - SomeIP测试专栏 - SomeIP Header - 03

前面已经简单的介绍了整帧SomeIP报文的组成部分,由于Ethernet报文头都是通用的,因此不会做详细的介绍,当然后面在介绍TC8中的TCP、UDP、IPv4、IPv6的时候也会做简单的介绍。不过在这里就不做介绍了,我们直接介绍SomeIP。 SomeIP Header 一、Message ID Message ID是由Serv…

Web3中文|构建Web3融资交易:股权和内部代币分配的比例

2017年,首次币发行(ICO,Initial Coin Offering,也称首次代币发售、区块链众筹,是用区块链把使用权和加密货币合二为一,来为开发、维护、交换相关产品或者服务的项目进行融资的方式)的融资方式激…

聚观早报 | 抖音超市上线;首架国产大飞机 C919 完成首次飞行

今日要闻:抖音超市上线;首架国产大飞机 C919 完成首次飞行;小鹏汽车计划有 5 款车型上市;2023年春节档电影总票房67.58亿元;亚洲首富被空头重创抖音超市上线 1 月 28 日消息,抖音超市已上线抖音 App&#x…

Javadoc(文档注释)详解

Java 支持 3 种注释,分别是单行注释、多行注释和文档注释。文档注释以/**开头,并以*/结束,可以通过 Javadoc 生成 API 帮助文档,Java 帮助文档主要用来说明类、成员变量和方法的功能。文档注释只放在类、接口、成员变量、方法之前…

vue+element高度仿照QQ音乐,完美实现PC端QQ音乐

一.前言 QQ音乐官网:点击访问作者成品效果预览:点击访问作者其他博客成品汇总预览:点击访问 暂时源码并没有提供其他获取渠道,私聊作者获取即可,或通过博客后面名片添加作者,很简单! 二.主要…

创建的vue项目--打包

自创建的项目(未使用项目框架),使用weabpack打包 1.在package.json文件中配置 2.在控制台执行打包命令npm run build 打包完成后,会在项目中生成一个dist文件夹,其中就是打包生成的静态文件 3.打开index.html&…

RocketMq基础详解

1、RocketMq的架构: 在RocketMq中有四个部分组成,分别是Producer,Consumer,Broker,以及NameServer,类比于生活中的邮局,分别是发信者,收信者,负责暂存,传输的…

找到二叉树中的最大搜索二叉树

题目 给定一棵二叉树的头节点 head,一致其中所有节点的值都不一样,找到含有节点最多的搜索二叉树,并返回这棵子树的头节点。 示例 分析 树形dp套路:如果题目求解目标是S规则,则求解流程可以定成以每一个节点为头节点…

【前端】如何判断是页面滚动还是窗口滚动

在写项目的时候遇到这个问题&#xff0c;现在举两个例子来记录这个问题。 页面滚动 html: <div class"temp"><template v-for"item in 100"><div>{{ item }}</div></template> </div>css: .temp {height: 100px;o…

老马闲评数字化「3」业务说了算还是技术说了算?

原文作者&#xff1a;行云创新CEO 马洪喜 导语 前两集和大伙聊了一下“数字化不转型行不行”以及“你的企业急不急着转”这两个话题。后面收到不少朋友的消息&#xff0c;说写的挺好&#xff0c;但“急着转、不敢转”的情况非常的普遍&#xff0c;有没有啥好主意给说一说。 麦…

冬去春来,ToB行业压缩的弹簧就要迸发了

目前来看&#xff0c;认知、实践、技术、服务这四方面的新变化&#xff0c;都将成为2023年企业数智化业务需求“井喷”的重要原因。 作者|周羽 出品|产业家 2023&#xff0c;冬去春来。 不止于字面。新的一年&#xff0c;中国的ToB厂商即将迎来“拨云见日”的朗朗晴空。 …

[文件上传工具类] MultipartFile 统一校验

目录​​​​​​​ 1. 创建上传文件的统一校验类 包含功能: -> 1. 多文件上传校验 -> 2. 文件名字校验(特殊符号) -> 3. 文件后缀校验 2. 使用方式 建议: 在文件上传开始的位置添加 -> 两个重载方法, 单文件 多文件都支持 -> 示例: 直接可以用, 任意位…

C++ 包装器function

目录 1、为什么需要包装器&#xff1f; 2、包装器的声明和使用 (1) 声明 (2) 实际应用 (3) 包装器接收类成员函数 3、包装器的绑定&#xff1a;bind函数 (1) 调整参数顺序 (2) 调整参数个数 1、为什么需要包装器&#xff1f; 函数模板可以接收各种不同类型的参数&…

光流正负值的含义以及如何利用光流进行warping

本文主要介绍光流的形式&#xff0c;光流值的正负代表什么含义&#xff0c;以及如何利用光流进行warping。 1. 光流正负值的含义 光流的概念&#xff1a;光流表示的是从reference frame到target frame&#xff0c;物体的移动。光流的形式&#xff1a;光流的表示也是数字化的。…