pytorch-复现经典深度学习模型-LeNet5

news2024/9/21 14:45:29

Neural Networks

  • 使用torch.nn包来构建神经网络。nn包依赖autograd包来定义模型并求导。 一个nn.Module包含各个层和一个forward(input)方法,该方法返回output

  • 一个简单的前馈神经网络,它接受一个输入,然后一层接着一层地传递,最后输出计算的结果。

    • 在这里插入图片描述
  • 神经网络的典型训练过程如下:

      1. 定义包含一些可学习的参数(或者叫权重)神经网络模型;
      2. 在数据集上迭代;
      3. 通过神经网络处理输入;
      4. 计算损失(输出结果和正确值的差值大小);
      5. 将梯度反向传播回网络的参数;
      6. 更新网络的参数,主要使用如下简单的更新原则: weight = weight - learning_rate * gradient

开始定义一个Lenet5网络:

  • import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # 1 input image channel, 6 output channels, 5x5 square convolution
            # kernel
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # an affine operation: y = Wx + b
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
        def forward(self, x):
            # Max pooling over a (2, 2) window
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            # If the size is a square you can only specify a single number
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
        def num_flat_features(self, x):
            size = x.size()[1:]  # all dimensions except the batch dimension
            num_features = 1
            for s in size:
                num_features *= s
            return num_features
    net = Net()
    print(net)
    
  • Net(
      (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
      (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
      (fc1): Linear(in_features=400, out_features=120, bias=True)
      (fc2): Linear(in_features=120, out_features=84, bias=True)
      (fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    
  • 在模型中必须要定义 forward 函数,backward 函数(用来计算梯度)会被autograd自动创建。 可以在 forward 函数中使用任何针对 Tensor 的操作。torch.nn 只支持小批量输入。整个 torch.nn 包都只支持小批量样本,而不支持单个样本。 例如,nn.Conv2d 接受一个4维的张量, 每一维分别是sSamples * nChannels * Height * Width(样本数*通道数*高*宽)。 如果你有单个样本,只需使用 input.unsqueeze(0) 来添加其它的维数net.parameters()返回可被学习的参数(权重)列表和值

  • params = list(net.parameters())
    print(len(params))
    for i in range(len(params)):
        print(f"第{i}层需要学习参数的参数量矩阵大小:",params[i].size())  )  
    
  • 100层需要学习参数的参数量矩阵大小: torch.Size([6, 1, 5, 5])1层需要学习参数的参数量矩阵大小: torch.Size([6])2层需要学习参数的参数量矩阵大小: torch.Size([16, 6, 5, 5])3层需要学习参数的参数量矩阵大小: torch.Size([16])4层需要学习参数的参数量矩阵大小: torch.Size([120, 400])5层需要学习参数的参数量矩阵大小: torch.Size([120])6层需要学习参数的参数量矩阵大小: torch.Size([84, 120])7层需要学习参数的参数量矩阵大小: torch.Size([84])8层需要学习参数的参数量矩阵大小: torch.Size([10, 84])9层需要学习参数的参数量矩阵大小: torch.Size([10])
    
  • 测试随机输入32×32。 注:这个网络(LeNet)期望的输入大小是32×32,如果使用MNIST数据集来训练这个网络,请把图片大小重新调整到32×32。

  • input = torch.randn(1, 1, 32, 32)
    out = net(input)
    print(out)
    
  • tensor([[ 0.0501,  0.1101, -0.0294,  0.0030,  0.0629,  0.0379,  0.0860,  0.0104,
              0.1108,  0.0916]], grad_fn=<AddmmBackward0>)
    
  • 将所有参数的梯度缓存清零,然后进行随机梯度的的反向传播:

  • net.zero_grad()
    out.backward(torch.randn(1, 10).data)
    print(out)
    
  • tensor([[-0.0275, -0.0630, -0.0253, -0.0811,  0.0872, -0.0282,  0.0926,  0.1020,
             -0.1034,  0.0727]], grad_fn=<AddmmBackward0>)
    
  • torch.Tensor:一个用过自动调用 backward()实现支持自动梯度计算的 多维数组 , 并且保存关于这个向量的梯度 w.r.t.

  • nn.Module:神经网络模块。封装参数、移动到GPU上运行、导出、加载等。

  • nn.Parameter:一种变量,当把它赋值给一个Module时,被 自动 地注册为一个参数。

  • autograd.Function:实现一个自动求导操作的前向和反向定义,每个变量操作至少创建一个函数节点,每一个Tensor的操作都回创建一个接到创建Tensor编码其历史 的函数的Function节点。

损失函数

  • 一个损失函数接受一对 (output, target) 作为输入,计算一个值来估计网络的输出和目标值相差多少。nn包中有很多不同的损失函数。 nn.MSELoss是一个比较简单的损失函数,它计算输出和目标间的均方误差, 例如:

  • output = net(input)
    target = torch.randn(10)  # 随机值作为样例
    target = target.view(1, -1)  # 使target和output的shape相同
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    print(loss)
    
  • tensor(0.8873, grad_fn=<MseLossBackward0>)
    
  • 反向过程中跟随loss , 使用它的 .grad_fn 属性,将看到如下所示的计算图。

  • input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
          -> view -> linear -> relu -> linear -> relu -> linear
          -> MSELoss
          -> loss
    
  • 当我们调用 loss.backward()时,整张计算图都会 根据loss进行微分,而且图中所有设置为requires_grad=True的张量 将会拥有一个随着梯度累积的.grad 张量。

反向传播

  • 调用loss.backward()获得反向传播的误差。但是在调用前需要清除已存在的梯度,否则梯度将被累加到已存在的梯度。现在,我们将调用loss.backward(),并查看conv1层的偏差(bias)项在反向传播前后的梯度。

  • net.zero_grad()     # 清除梯度
    print('conv1.bias.grad before backward')
    print(net.conv1.bias.grad)
    loss.backward()
    print('conv1.bias.grad after backward')
    print(net.conv1.bias.grad)
    
  • conv1.bias.grad before backward
    tensor([0., 0., 0., 0., 0., 0.])
    conv1.bias.grad after backward
    tensor([-0.0136,  0.0067,  0.0111,  0.0210,  0.0111, -0.0072])
    

更新权重

  • 当使用神经网络是想要使用各种不同的更新规则时,比如SGD、Nesterov-SGD、Adam、RMSPROP等,PyTorch中构建了一个包torch.optim实现了所有的这些规则。 使用它们非常简单:

  • import torch.optim as optim
    # create your optimizer
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    # in your training loop:
    for i in range(10):
        optimizer.zero_grad()   # zero the gradient buffers
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()    # Does the update
        print(f"第{i}轮训练:",loss)
    
  • 0轮训练: tensor(0.8873, grad_fn=<MseLossBackward0>)1轮训练: tensor(0.8725, grad_fn=<MseLossBackward0>)2轮训练: tensor(0.8582, grad_fn=<MseLossBackward0>)3轮训练: tensor(0.8447, grad_fn=<MseLossBackward0>)4轮训练: tensor(0.8309, grad_fn=<MseLossBackward0>)5轮训练: tensor(0.8173, grad_fn=<MseLossBackward0>)6轮训练: tensor(0.8035, grad_fn=<MseLossBackward0>)7轮训练: tensor(0.7897, grad_fn=<MseLossBackward0>)8轮训练: tensor(0.7764, grad_fn=<MseLossBackward0>)9轮训练: tensor(0.7623, grad_fn=<MseLossBackward0>)
    

035, grad_fn=)
第7轮训练: tensor(0.7897, grad_fn=)
第8轮训练: tensor(0.7764, grad_fn=)
第9轮训练: tensor(0.7623, grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/383841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营day47 |动态规划 198打家劫舍 213打家劫舍II 337打家劫舍III

day47198.打家劫舍1.确定dp数组&#xff08;dp table&#xff09;以及下标的含义2.确定递推公式3.dp数组如何初始化4.确定遍历顺序5.举例推导dp数组213.打家劫舍II情况一&#xff1a;考虑不包含首尾元素情况二&#xff1a;考虑包含首元素&#xff0c;不包含尾元素情况三&#x…

网络技术|网络地址转换与IPv6|路由设计基础|4

对应讲义——p6 p7NAT例题例1解1例2解2例3解3例4解4一、IPv6地址用二进制格式表示128位的一个IPv6地址&#xff0c;按每16位为一个位段&#xff0c;划分为8个位段。若某个IPv6地址中出现多个连续的二进制0&#xff0c;可以通过压缩某个位段中的前导0来简化IPv6地址的表示。例如…

1月奶粉电商销售数据榜单:销售额约20亿,高端化趋势明显

鲸参谋电商数据监测的2023年1月份京东平台“奶粉”品类销售数据榜单出炉&#xff01; 根据鲸参谋数据显示&#xff0c;1月份京东平台上奶粉的销量约675万件&#xff0c;销售额约20亿元&#xff0c;环比均下降19%左右。与去年相比&#xff0c;整体也下滑了近34%。可以看出&#…

真无线耳机哪个牌子好用?2023便宜好用的无线耳机推荐

蓝牙耳机经过近几年的快速发展&#xff0c;变得越来越普及&#xff0c;并且在一些性能上也做得越来越好。那么&#xff0c;真无线耳机哪个牌子好用&#xff1f;下面&#xff0c;我来给大家推荐几款便宜好用的无线耳机&#xff0c;可以参考一下。 一、南卡小音舱蓝牙耳机 参考…

Nuxt 3.0 全栈开发:五种渲染模式的差异和使用场景全解析

Nuxt 3.0 全栈开发 - 杨村长 - 掘金小册核心知识 工程架构 全栈进阶 项目实战&#xff0c;快速精通 Nuxt3 开发&#xff01;。「Nuxt 3.0 全栈开发」由杨村长撰写&#xff0c;299人购买https://s.juejin.cn/ds/S6p7MVo/ 前面我们提到过 Nuxt 能够满足我们更多开发场景的需求…

IGKBoard(imx6ull)-I2C接口编程之SHT20温湿度采样

文章目录1- 使能开发板I2C通信接口2- SHT20硬件连接3- 编码实现SHT20温湿度采样思路&#xff08;1&#xff09;查看sht20从设备地址&#xff08;i2cdetect&#xff09;&#xff08;2&#xff09;获取数据大体流程【1】软复位【2】触发测量与通讯时序&#xff08;3&#xff09;返…

日志收集笔记(Kibana,Watcher)

1 Kibana Kibana 是一个开源的分析与可视化平台&#xff0c;可以用 Kibana 搜索、查看存放在 Elasticsearch 中的数据&#xff0c;就跟谷歌的 elasticsearch head 插件类似&#xff0c;但 Kibana 与 Elasticsearch 的交互方式是各种不同的图表、表格、地图等&#xff0c;直观的…

【python】控制台中文输出乱码解决方案

注&#xff1a;最后有面试挑战&#xff0c;看看自己掌握了吗 文章目录控制台原因解决方法方法一方法二方法三如果是os.system函数乱码控制台原因 一般的情况下&#xff0c;还是我们的源码文件的编码格式问题。我们一般是要把源码文件的编码格式改成utf-8就好了&#xff0c;但是…

zeppelin安装及hive配置

一、zeppelin安装包 链接&#xff1a;https://pan.baidu.com/s/1DVmvY2TM7WmCskejTn8dzA 提取码&#xff1a;fl7r 二、安装zeppelin 将安装包传入Centos的/opt/install目录下 # 解压 tar -zxf /opt/install/zeppelin-0.10.0-bin-all.tgz -C /opt/soft/ # 重命名 mv /opt/sof…

Nodejs环境配置 | Linux安装nvm | windows安装nvm

文章目录一. 前言二. Linux Nodejs环境配置1. 安装nvm2. 配置npm三. Windows Nodejs环境配置1. 安装nvm2. 配置npm四. nvm基本使用一. 前言 由于在实际开发中一些不同的项目需要不同的npm版本来启动&#xff0c;所以本篇文章会基于nvm这个node版本管理工具来进行Linux和Winodw…

[AI助力] 2022.3.2 考研英语学习 2011 英语二翻译

[AI助力] 2022.3.2 考研英语学习 2011 英语二翻译 文章目录[AI助力] 2022.3.2 考研英语学习 2011 英语二翻译2011年英语二翻译真题自己写的看看AI的翻译谷歌翻译New Bing&#x1f602;让AI自我评价chatgpt&#x1f923;让AI自我评价DeepL有道腾讯翻译百度翻译IDEA翻译积累&…

智能家居项目(八)之树莓派+摄像头进行人脸识别

目录 1、编辑Camera.c 2、编辑contrlDevices.h 3、编辑mainPro.c 4、进行编译&#xff1a; 5、运行结果&#xff1a; ./test1 6、项目图片演示 智能家居项目&#xff08;七&#xff09;之Libcurl库与HTTPS协议实现人脸识别_Love小羽的博客-CSDN博客 经过上一篇文章&…

redhawk:Low Power Analysis

1.rush current与switch cell 在standby状态下为了控制leakage power我们选择power gating的设计方式&#xff0c;使用power switch cell关闭block/power domain的电源。 power switch的基本介绍可见: 低功耗设计-Power Switch power switch的table中有四种状态&#xff0c;…

Simulink 自动代码生成电机控制:优化Simulink生成的代码提升代码运行效率

目录 优化需求 优化方法 从模型配置优化 优化结果对比 从算法层优化 优化结果对比 总结 优化需求 本次优化的目的是提升FOC代码执行速度&#xff0c;以普通滑模观测器为例&#xff0c;优化前把速度环控制放到2ms的周期单独运行&#xff0c;把VOFA上位机通信代码放到主循…

mongodb入门到使用(上)

mongodb的安装与使用前言一、linux下载二、mongodb配置三、 mongodb服务管理启动服务查看停止四、远程连接五、SpringBoot整合总结前言 本文主要针对一些项目的部署服务器在使用方面用到了mongodb&#xff0c;参考解决一些部署方面遇到的问题。 一、linux下载 使用wget下载 w…

代数小课堂:向量代数(通过向量夹角理解不同的维度)

文章目录 引言I 计算向量的夹角1.1 毕达哥拉斯定理1.2 余弦定理1.3 计算向量的夹角II 向量夹角的应用2.1 用计算机自动筛选简历(对人进行分类)2.2 计算机进行文本自动分类的原理引言 根据余弦定理计算两个向量的夹角向量夹角的应用: 对文本进行自动分类、自动筛选简历。如果…

【上位机入门常见问题】Visual Studio 2022安装指导

Visual Studio 2022安装指导 这里给大家指导安装Visual Studio 2022 Community版本&#xff0c;也就是我们常说的社区版&#xff0c;这个版本是微软给开发者学习技术专门定制的免费版本&#xff0c;其他的专业版和企业版都是商业收费版本。对于我们学习&#xff0c;大家使用社…

使用Docker、navicat部署和连接GaussDB

一、在CentOS7上安装Docker工具 1.卸载之前老版本的Docker yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \docker-selinux \docker-engine-selinux \docker-engine \docker-ce 2.安装D…

TLS协议

TLS全称传输层安全协议&#xff0c;上一代是安全套接层&#xff08;SSL,不安全&#xff09;&#xff0c;用途广泛&#xff0c;最知名的是用于http&#xff0c;使http升级为https协议&#xff0c;最新版本为TLSv1.3&#xff08;推荐使用&#xff09;。TLS通过建立客户端和服务器…

Vue-router的引入和安装

什么是Vue-Router&#xff1f;Vue路由器是Vue.js的官方路由器&#xff0c;它与Vue.js核心深度集成&#xff0c;使用Vue轻松构建单页应用程序变得轻而易举。功能包括&#xff1a;嵌套路线映射动态路由模块化&#xff0c;基于组件的路由器配置路由参数&#xff0c;查询&#xff0…