[PyTorch][chapter 41][卷积网络实战-LeNet5]

news2024/10/6 2:26:46

前言

    这里结合前面学过的LeNet5 模型,总结一下卷积网络搭建,训练的整个流程

目录:

    1: LeNet-5 

    2:    卷积网络总体流程

    3:  代码


一  LeNet-5

      LeNet-5是一个经典的深度卷积神经网络,由Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一

参数

输出shape

输入层

[batch,channel,32,32]

  C1(卷积层) 

6@5x5 卷积核 ,stride=1 ,padding=0

[batch,6,28,28]

  S2(池化层) 

kernel_size=2,stride=2,padding=0

[batch,6,14,14]

  C3(卷积层)


 

16@5x5 卷积核,stride=1,padding=0

[batch,16,10,10]

 S4(池化层) 

kernel_size=2,stride=2,padding=0

[batch,16,5,5]

  C5(卷积层)


 

120@5x5卷积核,stride=1padding=0

[batch,120,1,1]

 F6-全连接层 

nn.Linear(in_features=120,  out_features=84)

[batch,120]

 Output-全连接层 

nn.Linear(in_features=120,  out_features=10)

[batch,10]


二 卷积网络的总体流程

     

2.1、nn.Module建立神经网络模型
          model = LeNet5()

          

2.2、建立此网络的可学习的参数,以及更新规则
       optimizer = optim.Adam(model.Parameters(), lr=1e-3) 

        梯度更新的公式

2.3、构建损失函数

        损失函数模型
        criteon = nn.CrossEntropyLoss() 

2.4    前向传播

      logits = model(x)

       根据现有的权重系数,预测输出

2.5   反向传播

      optimizer.zero_grad() #先将梯度归零w_grad
      loss.backward()       #反向传播计算得到每个参数的梯度值w_grad

      通过当前的loss ,计算梯度

2.6   利用optim 更新权重系数

       optimizer.step() #更新权重系数W

       利用优化器更新权重系数
          

        


  三  代码 

# -*- coding: utf-8 -*-
"""
Created on Thu Jun 15 14:32:54 2023

@author: chengxf2
"""
import torch
from torch import nn
from torch.nn import functional as F 
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim 
import ssl


class  LeNet5(nn.Module):
    
    
    """
    for cifar10 dataset
    """
    
    def __init__(self):
        
        super(LeNet5, self).__init__()
        
        self.conv_unit = nn.Sequential(
            
            #卷积层1 x:[b,3,32,32] => [b,6, 30,30]
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5,stride=1,padding=0),
            #池化层1
            nn.MaxPool2d(kernel_size=2,stride=2, padding =0),
            
            #卷积层2  
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5,stride=1, padding=0),
            #池化层2
            nn.MaxPool2d(kernel_size=2,stride=2, padding =0)
            #x:[b,16,5,5]
            )
        
        self.flatten = nn.Flatten(start_dim =1, end_dim = -1)
        
        self.fc_unit = nn.Sequential(
              nn.Linear(in_features=16*5*5, out_features=120),
              nn.ReLU(),
              nn.Linear(in_features=120, out_features=84),
              nn.ReLU(),
              nn.Linear(in_features=84, out_features=10)
              )
       
        
    
    def forward(self, x):
        '''
        

        Parameters
        ----------
        x : 
            [batch,channel=3, width=32, height=32].

        Returns
        -------
        out : 
            DESCRIPTION.

        '''
        #[b,3,32,32] =>[b,16,5,5]
        out = self.conv_unit(x)
        
        #print("\n 卷积层输出 :",out.shape)
        #[b,16,5,5]=>[b,16*5*5]
        out = self.flatten(out)
        #print("\n flatten层输出 :",out.shape)
        #[b,400]=>[b,10]
        out = self.fc_unit(out)
        #print("\n 全连接层输出 :",out.shape)
        
        #pred = F.softmax(out,dim=1)
        return out
            

            
def train():
    
    x = torch.randn(8,3,32,32)
    net = LeNet5()
    
    out = net(x)
    
    print(out.shape)
               

def main():
    
    batchSize =32 
    maxIter = 10
    dataset_trans = transforms.Compose([transforms.ToTensor(),transforms.Resize((32,32))]) 
    imgDir='./data'
    print("\n ---beg----")
    cifar_train = datasets.CIFAR10(root= imgDir,train=True, transform= dataset_trans,download =False) 
    cifar_test =  datasets.CIFAR10(root= imgDir,train=False,transform= dataset_trans,download =False) 
    train_data = DataLoader(cifar_train, batch_size=batchSize,shuffle=True)
    test_data = DataLoader(cifar_test, batch_size=batchSize,shuffle=True)
   
    print("\n --download finsh---")
    device = torch.device('cuda')
    # DataLoader迭代产生训练数据提供给模型 
    model = LeNet5().to(device)
    
    criteon = nn.CrossEntropyLoss() #前向传播计算loss
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) #反向传播
    
    for epoch in range(maxIter):
       
       for batchindex,(x,label) in enumerate(train_data):
          
          #x: [b,3,32,32]
          #label: [b]
          x,label = x.to(device),label.to(device)
          
          logits = model(x)
          loss = criteon(logits, label)
          
          #backpop
          optimizer.zero_grad()
          loss.backward()
          optimizer.step() #更新梯度
          
          if batchindex%500 ==0:
              print('batchindex {}, loss {}'.format(batchindex, loss.item()))
    
       model.eval()
       total_correct =0.0
       total_num = 0.0
       with torch.no_grad():
           
           for batchindex,(x,label) in enumerate(test_data):
               x,label = x.to(device),label.to(device)
               logits = model(x)
               pred = logits.argmax(dim=1)
               
               total_correct += torch.eq(pred, label).float().sum()
               total_num += x.size(0)
           acc = total_correct/total_num
           print('\n epoch: {} ,acc: {}  total_num: {}'.format(epoch, acc, total_num))
           
           
           

            
          
          
      

    
if __name__ == "__main__":
    
     main()
    
    
    

因为不是灰度图,训练10轮,acc 只有 epoch: 9 ,acc: 0.6310999989509583  total_num: 10000.0

可以把卷积核调整小一点

参考:

https://mp.csdn.net/mp_blog/creation/editor/131209651

课时79 卷积神经网络训练_哔哩哔哩_bilibili

课时77 卷积神经网络实战-1_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/652544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

zabbix-agent安装

1.CentOS release 5 1-1.centos5 32位 [rootLV zabbix]# cat /etc/redhat-release CentOS release 5 (Final) [rootLV zabbix]# uname -a Linux LV 2.6.18-53.el5xen #1 SMP Mon Nov 12 03:26:12 EST 2007 i686 i686 i386 GNU/Linux确定了系统centos5 32位rpm方式安装&#…

Ubuntu18.04离线安装redis

因需要安装redis的服务器无法连接互联网,所以需要离线安装。首先需要下载redis的安装包,之后进行安装,在安装之前需要保证gcc,g,make等依赖包已经安装。 1. 安装gcc等依赖包 依赖包安装请参考: Ubuntu18…

CI570 3BSE001440R1需要电流显示和就地/远传控制

​ CI570 3BSE001440R1需要电流显示和就地/远传控制 CI570 3BSE001440R1需要电流显示和就地/远传控制 如果变频器与通讯方式与DCS系统连接,则只需要计算1个通讯点,不需要计算其他点数。 (6)如DCS系统外接电磁阀、指示灯、接触器等…

物联网云平台数据存储方案,这次我终于找对了

《高并发系统实战派》-- 你值得拥有 文章目录 物联网云平台存储概述为什么要做存储?存储的意义在哪里?数据存储方案设计存储数据库选型需要考虑的因素数据库选型结构化数据半结构化数据非结构化数据 案例分析第一颗栗子第二颗栗子第三颗栗子第四颗栗子 …

Web安全:vulhub 靶场搭建.(各种漏洞环境集合,一键搭建漏洞测试靶场)

Web安全:vulhub 靶场搭建. Vulhub是一个面向大众的开源漏洞靶场,无需docker知识,简单执行两条命令即可编译、运行一个完整的漏洞靶场镜像。让漏洞复现变得更加简单,让安全研究者更加专注于漏洞原理本身. 目录: Web安…

DLL修复工具下载,解决DLL文件问题的方法

在计算机应用程序中,我们经常会遇到一些错误提示,如“找不到.dll文件”或“无法加载.dll文件”。这些问题通常是由于缺少或损坏的DLL文件造成的。为了解决这些问题,我们可以借助DLL修复工具来修复和恢复DLL文件。本文将介绍什么是DLL文件&…

C# 自动更新(基于FTP)

效果 启动软件后,会自动读取所有的 FTP 服务器文件,然后读取本地需要更新的目录,进行匹配,将 FTP 服务器的文件同步到本地 Winform 界面 一、前言 在去年,我写了一个 C# 版本的自动更新,这个是根据配置文…

黑盒、白盒、灰盒,如何选择合适的模糊测试工具?

在软件开发和安全领域,模糊测试是一种常用技术,用于发现应用程序或系统中的潜在漏洞和安全弱点。选择不同的模糊测试方法将极大地影响测试的有效性和效率。本文将比较对比黑盒、白盒和灰盒模糊测试的特点和优势并提供选型指导。 模糊测试的分类 黑盒模糊…

人工智能带来的利好将继续推动Palantir股价反弹

来源:猛兽财经 作者:猛兽财经 Palantir的人工智能产品已经获得行业认可 Palantir(PLTR)成立于2003年,是一家专注于数据分析和人工智能的知名软件公司。其尖端产品已应用到了政府机构、企业和非营利组织等各个行业&…

华为OD机试 Java 实现【输出单向链表中倒数第k个结点】【LeetCode练习题】,附详细解题思路

一、题目描述 输入一个单向链表,输出该链表中倒数第k个结点,链表的倒数第1个结点为链表的尾指针。 链表结点定义如下: class ListNode{int value;ListNode next;public ListNode(){}public ListNode(

解决H5在native中键盘弹起影响页面交互

您好,如果喜欢我的文章,可以关注我的公众号「量子前端」,将不定期关注推送前端好文~ 问题描述 在native中拉起键盘再收回,滚动列表实际距离发生变化,被键盘一起弹上去了(我这里大约是400px的样子&#xf…

使用yolox训练自己的数据集并测试

1.首先给出yolox原模型的下载地址: ​​​​​​https://github.com/bubbliiiing/yolox-pytorch 百度网盘链接给出自己完整的模型(包括数据集以及权重文件): 链接:https://pan.baidu.com/s/1JNjB42u9eGNhRjr1SfD_Tw 提取码&am…

Redis和Redis可视化管理工具的下载和安装

文章目录 Redis 简介一,Redis 下载二,Redis 安装三,Redis 配置四,Redis 启动 Redis-Desktop-Manager 简介一,Redis-Desktop-Manager 下载二,Redis-Desktop-Manager 安装三,Redis-Desktop-Manage…

深度学习笔记之Transformer(三)自注意力机制

深度学习笔记之Transformer——自注意力机制 引言回顾:缩放点积注意力机制自注意力机制自注意力机制与 RNN,CNN \text{RNN,CNN} RNN,CNN的对比简单介绍:卷积神经网络处理序列信息的原理从计算复杂度的角度观察 位置编码 引言 上一节对注意力分数 ( Atte…

关于前端Vue脚手架的完整搭建

创建脚手架 在VSC中打开命令行&#xff0c;输入如下命令可以用于创建脚手架 Vue create <项目名称>会出现如下选项&#xff1a; 前面是选项的名称&#xff0c;括号中的是选项包含有&#xff1a; 1、Vue的版本 2、babel是用于将高版本的js转化成为低版本的js&#xff0…

SSM 整合案例

Ssm整合 注意事项 Spring mvcSpringMybatisMySQL项目结构&#xff0c;以web项目结构整合&#xff08;还有maven、gradle&#xff09;整合日志、事务一个项目中Idea 2019、JDK12、Tomcat9、MySQL 5.5 项目结构 D:\java\Idea\Idea_spacework\SSMhzy&#xff1a;不会就去找项目…

chatgpt赋能python:Python怎么筛选奇数

Python怎么筛选奇数 Python是一种高级编程语言&#xff0c;既具有面向对象编程的特点&#xff0c;又可以进行函数式编程。Python的语法简洁、清晰&#xff0c;非常适合初学者学习。在Python中&#xff0c;筛选奇数的方法非常简单&#xff0c;本文将介绍Python中筛选奇数的方法…

人机交互学习-5 交互式系统的需求

交互式系统的需求 需求是什么需求需求活动 产品特性用户特性体验水平差异新手用户专家用户中间用户 年龄差异老年人儿童 文化差异健康差异 用户建模人物角色人物角色的作用人物角色的构造错误观点人物角色基于问题举例注意事项 建模过程 需求获取、分析和验证观察场景人物角色场…

爬虫数据是如何收集和整理的?

爬虫数据的收集和整理通常包括以下步骤&#xff1a; 确定数据需求&#xff1a;确定要收集的信息类型、来源和范围。 网络爬取&#xff1a;使用编程工具&#xff08;如Python的Scrapy、BeautifulSoup等&#xff09;编写爬虫程序&#xff0c;通过HTTP请求获取网页内容&#xff…

网卡命名规则和网卡变动结论

net.ifnames0 biosdevname0 插卡前状态&#xff1a; 插卡后状态&#xff1a; 结论&#xff1a;明显eth0 MAC地址从00:0d:48:94:10:fc 变更为 c0:33:da:10:31:ff。该方法eth0实际对应的网口发生了变动。 net.ifname1 插卡前状态&#xff1a; 插卡后状态&#xff1a; 查看…