搭建不同网络训练MNIST

news2024/11/17 11:51:59

问题
在之前的学习过程中,我们学习了如何搭建全连接神经网络训练Mnist数据集。初始时,全连接神经网络训练结果验证集和训练集的精确度不高,在对数据进行归一化,调参等操作提高了精确度。我们这次使用Le-Net5和VGG对MNIST进行训练,VGG采样层太多,计算量庞大,我们只进行搭建,也可以采用Google Colab进行训练。比较全连接和卷积神经网络异同。
Le-net5网络如下。
90fa7fd6b54fff01a62ada84a65fec7b.png

方法
搭建le-net5神经网络,定义了两层卷积层、两个均值池化以及三层全连接网络。使用summary可查看得到每层网络的输入和输出。

import torch
from torch import nn
from torchinfo import summary
import torch.nn.functional as F
import torchvision
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from torchvision.transforms import Normalize
import time
start=time.time()
class MyNet(nn.Module):
   # x=1*28*28 c*h*w c表示通道数,h表示图像高度,w表示图像宽度
   #定义网络有哪些层,层作为成员变量
   def __init__(self) -> None:
       super().__init__()
       self.conv1=nn.Conv2d(in_channels=1,
                  out_channels=6,
                  kernel_size=5,
                  stride=1,
                  padding=0)#[28]
       self.avg_pool1 = nn.AvgPool2d(
           2,
           stride=2,
       )
       self.conv2=nn.Conv2d(in_channels=6,
                  out_channels=16,
                  kernel_size=5,
                  stride=1,
                 )
       self.avg_pool2 = nn.AvgPool2d(
           2,
           stride=2,
       )
       self.fc1=nn.Linear(in_features=4 * 4 * 16,
                       out_features=120,
                                   )
       self.fc2=nn.Linear(in_features=120,
                       out_features=84,
                                   )
       self.fc3 = nn.Linear(in_features=84,
                            out_features=10,
                            )
   def forward(self,x):
       x=torch.relu(self.conv1(x))
       x=torch.relu(self.avg_pool1(x))
       x=torch.relu(self.conv2(x))
       x=torch.relu(self.avg_pool2(x))
       x=torch.flatten(x,1)# nn.Flatten()默认从dim=1开始 torch.flatten()默认从dim=0开始
       x=torch.relu(self.fc1(x))
       x=torch.relu(self.fc2(x))
       out=torch.relu(self.fc3(x))
       return out

搭建VGG11,定义了八个卷积层,五个最大池化层和三层全连接层。由于网络搭建复杂,计算量过大,运行很慢。VGG模型框架如下。
5b4a450578eced94fd1a3c975400448b.png

class MyNet(nn.Module):
   # x=1*28*28 c*h*w c表示通道数,h表示图像高度,w表示图像宽度
   #定义网络有哪些层,层作为成员变量
   def __init__(self) -> None:
       super().__init__()
       self.conv1=nn.Conv2d(in_channels=1,
                  out_channels=64,
                  kernel_size=3,
                  stride=1,
                  padding=1)#[B,64,224,224]
  self.MaxPool_1=nn.MaxPool2d(kernel_size=2,
                                 stride=2)#[64,112,112]
       self.conv2=nn.Conv2d(in_channels=64,
                            out_channels=128,  # [B,128,112,112]
                           kernel_size=3,
                            stride=1,
                            padding=1)
       self.MaxPool_2= nn.MaxPool2d(kernel_size=2,
                                 stride=2)#[128,56,56]
       self.conv3=nn.Conv2d(in_channels=128,
                            out_channels=256,  # [B,256,56,56]
                           kernel_size=3,
                            stride=1,
                            padding=1)
       self.conv4 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1#[256,56,56]
       )
       self.MaxPool_3 = nn.MaxPool2d(
           kernel_size=2,
           stride=2)#[256,28,28]
       self.conv5 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1#[512,56,56]
       )
       self.conv6 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1#[512,56,56]
       )
       self.MaxPool_4 = nn.MaxPool2d(
           kernel_size=2,
           stride=2)#[512,14,14]
       self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,#[512,28,28]
           padding=1
       )
       self.conv8 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,#[512,56,56]
           padding=1
       )
       self.MaxPool_5=nn.MaxPool2d(
           kernel_size=2,
           stride=2)#[512,7,7]
       self.fc1 = nn.Linear(
           in_features=512 * 7 * 7,
           out_features=4096
       )
       self.fc2 = nn.Linear(
           in_features=4096,
           out_features=4096
       )
       self.fc3 = nn.Linear(
           in_features=4096,
           out_features=10
       )
   def forward(self,x):
       x=torch.relu(self.conv1(x))
       x=torch.relu(self.MaxPool_1(x))
       x=torch.relu(self.conv2(x))
       x=torch.relu(self.MaxPool_2(x))
       x = torch.relu(self.conv3(x))
       x = torch.relu(self.conv4(x))
       x = torch.relu(self.MaxPool_3(x))
       x = torch.relu(self.conv5(x))
       x = torch.relu(self.conv6(x))
       x = torch.relu(self.MaxPool_4(x))
       x = torch.relu(self.conv7(x))
       x = torch.relu(self.conv8(x))
       x = torch.relu(self.MaxPool_5(x))
       x=torch.flatten(x,1)# nn.Flatten()默认从dim=1开始 torch.flatten()默认从dim=0开始
       x=torch.relu(self.fc1(x))
       x=torch.relu(self.fc2(x))
       out=torch.relu(self.fc3(x))
       return out

对比batch_size=64的全连接网络和le-net5卷积神经网络对MNIST数据集的效果。对数据结果进行数据可视化。
a11889d4691dff81826b764b8ad32ed8.png

结语

图像的数据是一个矩阵,也就是一个像素点的和它所在位置的上下左右像素点有很大的相关性,将像素矩阵flatten后,左右的位置关系保留了,但是上下的位置关系却被破坏了。

卷积神经网络和全连接神经网络的不同之处在于,卷积神经网络的输入是图像的原始矩阵,这样保留了图像的上下左右位置关系。

数据可视化的结果显示,fcn和Le-net5的精确度达到最好的效果,验证集和训练集的精确度高达99%,FCN这组数据有些过拟合,FCN存在参数过多,而使用CNN可以减少参数,降低过拟合,在 Le-net5模型搭建为了降低过拟合,使用了Dropout方法。在训练VGG过程中可以降低batch_size,也可通过Google Colab进行训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/186783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32—串口

串口介绍 串行接口简称串口,也称串行通信接口或串行通讯接口(通常指COM接口),是采用串行通信方式的扩展接口。串行接口(Serial Interface)是指数据一位一位地顺序传送。其特点是通信线路简单,只…

.NET MAUI 安卓 UI 资源设置

本文主要介绍使用 MAUI 开发安卓应用时,如何更换和处理 UI 资源:应用名称,图标,主题配色,状态栏,闪屏。 文章目录1. 背景2. 资源设置2.1 项目创建2.2 应用名称2.3 应用图标2.4 应用闪屏2.5 沉浸式状态栏1. …

通用智能如何拥有生命的简单设计

如第一个图所示 是和环境交互的时候 行为交互时间和环境反馈时间T0 T1 还有行为消耗能量E0 环境反馈能量E1 如图有四种情况 其中反馈时间T1小于交互时间T0的任务是积极反馈 和打游戏一样要及时反馈才能提起兴趣 这个是整个行为交互过程中最小的记录单元 图二的每个元素都代表多…

STM32项目-STM32智能小车-电子设计大赛-STM32cubemx-STM32f103c8t6STM32串口通信-

记录项目的详细制作过程,所以笔记很长,图很多、很多图不好CSDN搬运, 我把笔记放网盘或者自己根据资料下载 笔记网盘下载: 链接:https://pan.baidu.com/s/1Mk2EVIha7Fpj4Xductg3Uw?pwdVCC1 提取码:VCC1 笔记CSDN下载:…

C++11 入门

作者:小萌新 专栏:C进阶 作者简介:大二学生 希望能和大家一起进步! 本篇博客简介:介绍C11的一些背景知识 本篇博客主要是讲解一些关键字 C11前言C11诞生简介列表初始化{}初始化关键字autodecltypenullptr范围forSTL的更…

技术开发117

技术开发117 业务内容: 半导体制造设备零件(阀门零件、管件)、汽车的各种功能部件(发动机、动力转向器、空调、刹车和传动系统)、 建筑和工业设备部件、电信设备的零件、气动特殊气缸、供热和制冷系统的零件、其他一…

CAD中怎么绘制攒尖屋顶?CAD设计攒尖屋顶技巧

在给排水CAD设计中,有些时候为了需要会在图纸中绘制攒尖屋顶,那么你知道CAD软件中怎么构造攒尖屋顶三维模型吗?其实很简单,浩辰CAD给排水软件中提供了实用的攒尖屋顶功能,下面就和小编一起来看看浩辰CAD给排水软件中CA…

Android RCLayout 圆角布局,支持边框,渐变色,渐变色方向等

RCLayout 圆角布局,支持边框,渐变色,渐变色方向等 支持布局 RcRelativeLayout RcLinearLayout RcFrameLayout RcConstraintLayout RcAbsoluteLayout RcTextView 引入 implementation com.github.IHoveYou:RCLayout:1.0.1 项目地址 链接: github 布局属性 <!-- 背景色/渐…

Java-基础-1.异常

一&#xff1a;异常架构 Error 类层次描述了 Java 运行时系统内部错误和资源耗尽错 误。这类错误是我们无法控制的&#xff0c;同时也是非常罕见的错误。所以在编程中&#xff0c;不去处理这类错误。Error 表明系统 JVM 已经处于不可恢复的崩溃状态中。我们不需要管他。 如:写代…

电商前台项目(五):完成加入购物车功能和购物车页面

Vue2项目前台开发&#xff1a;第五章一、加入购物车1.路由跳转前先发请求把商品数据给服务器&#xff08;1&#xff09;观察接口文档&#xff08;2&#xff09;写接口&#xff08;3&#xff09;dispatch调用接口传数据&#xff08;4&#xff09;判断服务器是否已经收到商品数据…

Spring-相关概念入门

Spring-相关概念&入门 2&#xff0c;Spring相关概念 2.1 初识Spring 在这一节&#xff0c;主要通过以下两个点来了解下Spring: 2.1.1 Spring家族 官网&#xff1a;https://spring.io&#xff0c;从官网我们可以大概了解到&#xff1a; Spring能做什么:用以开发web、微服…

六、附近商户,连续签到,UV统计

文章目录附近商户GEO的基本用法导入店铺数据到GEO实现附近商户功能签到BitMap的基本用法实现签到功能实现连续签到统计功能补充&#xff1a;Java中>>和>>>的区别UV统计HyperLogLog的基本用法测试百万数据的统计官方命令文档&#xff1a;https://redis.io/comman…

OpenGLES(一)——介绍

一、OpenGL介绍 OpenGL&#xff08;全写Open Graphics Library&#xff09;是指定义了一个跨编程语言、跨平台的编程接口规格的专业的图形程序接口。它用于三维图像&#xff08;二维的亦可&#xff09;&#xff0c;是一个功能强大&#xff0c;调用方便的底层图形库。     O…

六、创建Gitee仓库和提交代码

1、创建仓库 1.1、创建远程仓库 (1)登录Gitee.com&#xff0c;点击右上角 号&#xff0c;再点击新建仓库。 (2)填写仓库名称&#xff0c;设置公开(一般指开源项目)或者私有&#xff0c;其他默认(也可以根据自己需要选择) (3)这里要勾选设置模板&#xff0c;Readme文件。(如果…

Java I/O 流详解(Basic I/O)

目录 1、Java Basic I/O 中的字节流&#xff1a;Byte Streams 2、Java Basic I/O 中的字符流&#xff1a;Character Streams 3、Java Basic I/O 中的缓冲流&#xff1a;Buffered Streams 4、Java Basic I/O 中的打印流&#xff1a;PrintStream &#xff08;数据扫描和格式化…

网易二面:CPU狂飙900%,该怎么处理?

说在前面 社群一位小伙伴面试了 网易&#xff0c;遇到了一个 性能类的面试题&#xff1a; CPU飙升900%&#xff0c;该怎么处理&#xff1f; 可惜的是&#xff0c;以上的问题&#xff0c;这个小伙没有回答理想。 最终&#xff0c;导致他网易之路&#xff0c;终止在二面&…

【蓝桥杯】Python字符串处理和应用

前言&#xff1a; 本文侧重于通过实战训练来提高字符串的处理能力&#xff0c;可以先行学习一下我之前的文章&#xff1a;蓝桥杯Python快速入门&#xff08;4&#xff09; &#xff0c;学习完基础知识再来刷题才会事半功倍&#xff01; 字符串处理 # 字符串切片 str1"1…

优秀码农选择对象详细指南,看完记得要实战噢

2023年了&#xff0c;你是否已到了法定年纪&#xff0c;那么这一篇优秀码农选择对象的详细指南&#xff0c;你一定用得到&#xff0c;看完记得感谢狗哥哦&#xff01; 目录 一、对于婚姻先来思考这么几条 1. 太快决定结婚&#xff1f; 2. 一方或双方急于结婚&#xff1f; 3.…

【go语言入门教程】——1. go语言介绍及安装

目 录1. go 语言简介2. go 语言安装2.1 下载安装包2.2 安装 go2.3 验证安装结果3. 使用 VS Code 运行 go 程序1. go 语言简介 go的产生 go 是一个开源的编程语言&#xff0c;它能让构造简单、可靠且高效的软件变得容易。 go 是从 2007 年末由 Robert Griesemer, Rob Pike, Ken…

Linux系统常见问题总结(持续更新)

目录一&#xff0c;vim安装与设置1&#xff0c;安装2&#xff0c;配置二&#xff0c;Found a swap file by the name三&#xff0c;docker启动失败&#xff1a;Job for docker.service failed because the control process exited with error四&#xff0c;docker-compose安装r…