经典卷积神经网络 - LeNet

news2024/12/23 16:40:35

image-20231022164234916

该模型用于手写的数字识别。
LeNet模型包含了多个卷积层和池化层,以及最后的全连接层用于分类。其中,每个卷积层都包含了一个卷积操作和一个非线性激活函数,用于提取输入图像的特征。池化层则用于缩小特征图的尺寸,减少模型参数和计算量。全连接层则将特征向量映射到类别概率上。

MNISt数据集

50000个训练数据,10000个测试数据。图像大小为28x28,共10类(0~9)。

  • LeNet是早期成功的神经网络
  • 先使用卷积层来学习图片空间信息
  • 然后使用全连接层来转换到类别空间

对于padding

通用的卷积时padding 的选择

如卷积核宽高为3时 padding 选择1

如卷积核宽高为5时 padding 选择2

如卷积核宽高为7时padding选择3

至于选择填充多少像素,通常有两个选择,分别叫做Valid卷积和Same卷积。

Valid卷积意味着不填充,这样的话,如果你有一个 n × n n\times n n×n的图像,用一个 f × f f\times f f×f的过滤器卷积,它将会给你一个 ( n − f + 1 ) × ( n − f + 1 ) (n-f+1)\times (n-f+1) (nf+1)×(nf+1)维的输出。例如,有一个6×6的图像,通过一个3×3的过滤器,得到一个4×4的输出。

Same卷积意味你填充后,你的输出大小和输入大小是一样的。根据这个公式 n − f + 1 n-f+1 nf+1,当你填充 p p p个像素点,n就变成了 n + 2 p n+2p n+2p,最后公式变为 n + 2 p − f + 1 n+2p-f+1 n+2pf+1。因此如果你有一个 n × n n\times n n×n的图像,用 p p p个像素填充边缘,输出的大小就是这样的 ( n + 2 p − f + 1 ) × ( n + 2 p − f + 1 ) (n+2p−f+1)\times (n+2p−f+1) (n+2pf+1)×(n+2pf+1)。如果你想让 ( n + 2 p − f + 1 ) = n (n+2p−f+1)=n (n+2pf+1)=n的话,使得输出和输入大小相等,如果你用这个等式求解 p p p,那么 p = ( f − 1 ) / 2 p=(f-1)/2 p=(f1)/2。所以当 f f f是一个奇数的时候,只要选择相应的填充尺寸,你就能确保得到和输入相同尺寸的输出。

代码实现

LeNet(LeNet-5)由两个部分组成:卷积编码器和全连接层密集块。

model.py

from torch import nn

class Reshape(nn.Module):
    def forward(self,x):
        return x.reshape((-1,1,28,28))

class MyLeNet(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # 假如sigmoid激活函数后 损失不下降
        self.model = nn.Sequential(
            Reshape(),
            nn.Conv2d(1,6,kernel_size=5,padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(2,stride=2),
            nn.Conv2d(6,16,kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(2,stride=2),
            nn.Flatten(),
            nn.Linear(16*5*5,120),
            nn.Linear(120,84),
            nn.Linear(84,10)
        )

    def forward(self,x):
        return self.model(x)

train.py

# 扫描数据次数
epochs = 20
# 分组大小
batch = 64
# 学习率
learning_rate = 0.05
# 训练次数
train_step = 0
# 测试次数
test_step = 0


# 定义图像转换
transform = transforms.Compose([
    transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.MNIST(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))

# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# 创建网络
net = MyLeNet()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)

writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):
    print("-------------------第 {} 轮训练开始-------------------".format(epoch))
    net.train()
    for data in train_dataloader:
        train_step = train_step + 1
        images,targets = data
        images = images.to(device)
        targets = targets.to(device)
        outputs = net(images)
        loss_out = loss(outputs,targets)
        optimizer.zero_grad()
        loss_out.backward()
        optimizer.step()

        if train_step%100==0:
            writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)
            print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))

    # 测试
    net.eval()
    total_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            test_step = test_step + 1
            images, targets = data
            images = images.to(device)
            targets = targets.to(device)
            outputs = net(images)
            loss_out = loss(outputs, targets)
            total_loss = total_loss + loss_out
            accuracy = (targets == torch.argmax(outputs,dim=1)).sum()
            total_accuracy = total_accuracy + accuracy
        # 计算精确率
        print(total_accuracy)
        accuracy_rate = total_accuracy / test_size

        print("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))
        print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))
        writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)
        writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)
    torch.save(net,"./model/net_{}.pth".format(epoch+1))
    print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1123805.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go运算操作符全解与实战:编写更高效的代码!

本文全面探讨了Go语言中的各类运算操作符,从基础的数学和位运算到逻辑和特殊运算符。文章在深入解析每一种运算操作符的工作原理、应用场景和注意事项,以帮助开发者编写更高效、健壮和可读的Go代码。 简介 Go语言,作为一种现代的编程语言&am…

Python实战小项目分享

Python实战小项目包括网络爬虫、数据分析和可视化、文本处理、图像处理、聊天机器人、任务管理工具、游戏开发和网络服务器等。这些项目提供了实际应用场景和问题解决思路,可以选择感兴趣的项目进行实践,加深对Python编程的理解和掌握。在实践过程中&…

CRM销售管理系统是如何进行数据分析的

数据分析可以帮助销售人员挖掘潜在问题,知晓哪些渠道可以带来更多的客户,为日常的销售工作提供科学依据。当然,要做好数据分析不是一件简单的事,利用好销售管理系统是关键。那么CRM销售管理系统是如何进行数据分析的呢&#xff1f…

TCP网络通信

package TCP1;//完成TCP通信的 实现发1收1import java.io.DataOutputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.net.InetAddress; import java.net.Socket;public class Client {public static void main(String[] args)throws Ex…

计算属性和侦听属性以及方法有什么区别,本文以计算一个数组中所有偶数的和为例

计算属性(computed)是Vue中的一个特殊属性,它根据依赖的数据进行计算,并返回计算结果。计算属性的值会根据其相关依赖项的变化而自动更新,类似于一个响应式的缓存。计算属性可以用来处理一些复杂的逻辑计算,避免在模板中编写过多的…

asp.net网球馆计费管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net网球馆计费管理系统是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语 言开发 aspnet网球馆计费管理系统1 二、…

windows系统mysql服务启动失败

​ 原因 电脑重启navicat连接mysql失败,在电脑-管理-服务没有mysql服务 解决方案 找到mysql的安装目录进入bin目录 执行mysqld --install 进行重新安装 提示服务安装成功 net start mysql mysql 启动成功 ​

java编译时指定classpath

说明 Java编译时可以通过选项--class-path <path>&#xff0c;或者 -classpath <path>&#xff0c;或者-cp <path>来指定查找用户类文件、注释程序处理程序、或者源文件的位置。这个设置覆盖CLASSPATH环境变量的设置。如果没有设置-sourcepath&#xff0c;那…

【Gensim概念】01/3 NLP玩转 word2vec

第一部分 词法 一、说明 Gensim是一种Python库&#xff0c;用于从文档集合中提取语义主题、建立文档相似性模型和进行向量空间建模。它提供了一系列用于处理文本数据的算法和工具&#xff0c;包括主题建模、相似性计算、文本分类、聚类等。在人工智能和自然语言处理领域&…

简历上的哪些内容,才是面试官眼中的干货?

在准备面试时&#xff0c;简历是我们的敲门砖&#xff0c;它是我们与面试官沟通的第一步。因此&#xff0c;简历的内容对我们的求职成功至关重要。那么&#xff0c;简历上哪些内容才是面试官眼中的干货呢&#xff1f; 第一&#xff0c;简历的格式和排版应该整洁、清晰、易读。简…

PyQt项目实战1

转载 pyqt5:利用QFileDialog从本地选择图片\文本文档显示到label、保存图片\label文本到本地&#xff08;附代码&#xff09;_pyqt5中qfiledialog.getopenfileurl-CSDN博客https://blog.csdn.net/tensixchuan/article/details/1057178331、QtDesigner的控件摆设完成后&#xf…

告别杂音干扰,享受纯净通话:华为Mate 60 Pro降噪功能体验

作为一名销售&#xff0c;我经常需要使用手机跟客户进行通话。但是&#xff0c;有时候环境嘈杂或者对方的声音不够清晰&#xff0c;让我感到非常烦恼。好在我最近入手了一款华为Mate 60 Pro手机&#xff0c;发现通话功能也有惊喜新升级&#xff0c;它带来的降噪功能让我重新定义…

先后在影酷/传祺E9/昊铂GT量产交付,这家ADAS厂商何以领跑

智能泊车赛道正在迎来黄金增长期&#xff0c;以魔视智能为代表的玩家正在驶入大规模量产的“快车道”。 继在广汽传祺影酷、广汽传祺 E9实现规模化量产交付之后&#xff0c;魔视智能的Magic Parking智能泊车系列解决方案再度在广汽埃安旗下高端智能轿跑——昊铂GT上面实现量产…

基于YOLOv5[n/s/m/l/x]全系列参数模型开发构建小麦麦穗智能化精准检测识别计数系统

在前文中我们已经开发实践了小麦颗粒和小麦麦穗的检测&#xff0c;感兴趣可以自行移步阅读即可&#xff1a; 《基于YOLOv5[n/s/m/l/x]全系列参数模型开发构建小麦麦穗颗粒智能化精准检测识别计数系统》 《基于轻量级yolov5nCBAM开发构建全球小麦麦穗智能检测计数系统》 在上…

LiveGBS流媒体平台GB/T28181功能-报警预案配置告警触发报警时截图及录像摄像头通过GB28181上报报警

LiveGBS报警预案配置告警触发报警时截图及录像摄像头通过GB28181上报报警 1、报警信息1.1、报警查询1.2、配置开启报警订阅1.2.1、国标设备编辑1.2.2、选择开启报警订阅 1.3、配置摄像头报警1.3.1、配置摄像头报警通道ID1.3.2、配置摄像头开启侦测1.3.3、尝试触发摄像头报警1.3…

【TES605】基于Virtex-7 FPGA的高性能实时信号处理平台

板卡概述 TES605是一款基于Virtex-7 FPGA的高性能实时信号处理平台&#xff0c;该平台采用1片TI的KeyStone系列多核DSP TMS320C6678作为主处理单元&#xff0c;采用1片Xilinx的Virtex-7系列FPGA XC7VX690T作为协处理单元&#xff0c;具有2个FMC子卡接口&#xff0c;各个处理节…

某网站cookies携带https_ydclearance获取正文

1、url aHR0cHM6Ly9iYnMuNTFjcmVkaXQuY29tL3RocmVhZC03ODI0OTAzLTEtMS5odG1s2、抓包 根据抓包返回的两个请求进行访问&#xff0c;发现获取正文需cookies携带https_ydclearance cookies {https_ydclearance: 6973fc7d30e4fe01c1bdde9f-ff5e-4d22-bfc2-00e5ab7769b7-16980360…

【tg】8: Manager的主要功能

Manager 提供的是media thread 说明media thread 是主线程&#xff0c; 而 mediamgr里是worker threadnetworkmgr是network thread了。 Manager 的功能重要&#xff0c;但是特别短 G:\CDN\P2P-DEV\tdesktop-offical\Telegram\ThirdParty\tgcalls\tgcalls\Manager.cpp class…

在线零售多用户多门店连锁商城系统

在线零售多用户商城系统和多门店连锁商城系统的核心都是线上线下相结合的&#xff0c;线上和线下结合&#xff0c;一体化是在线新零售多用户商城系统发展的趋势&#xff0c;现在移动互联网时代&#xff0c;越来越多的传统企业&#xff0c;如&#xff1a;连锁店铺&#xff0c;连…

Unity | Image 自定义顶点数据实现圆角矩形

1 圆角方案简介 UGUI 中的 Image 实现圆角效果通常有三种方式&#xff0c;Mask、Shader以及自定义顶点数据&#xff0c;相比于前两者&#xff0c;自定义顶点数据的使用方式更加灵活&#xff0c;同时可以减少 DrawCall&#xff0c;但是会增加顶点及三角形数量。最终实现方案可根…