【学习笔记】深度学习实战 | LeNet

news2024/11/15 21:57:24

在这里插入图片描述

简要声明


  1. 学习相关网址
    1. [双语字幕]吴恩达深度学习deeplearning.ai
    2. Papers With Code
    3. Datasets
  2. 深度学习网络基于PyTorch学习架构,代码测试可跑。
  3. 本学习笔记单纯是为了能对学到的内容有更深入的理解,如果有错误的地方,恳请包容和指正。

参考文献


  1. PyTorch Tutorials [https://pytorch.org/tutorials/]
  2. PyTorch Docs [https://pytorch.org/docs/stable/index.html]
  3. LeNet (1998) [Gradient-based learning applied to document recognition]

简要介绍


LeNet

在这里插入图片描述

DatasetMNIST
Input (feature maps)32×32 (28×28)
CONV Layers2
FC Layers2
ActivationSigmoid
Output10

代码分析


函数库调用

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

处理数据

数据下载

# 从开放数据集中下载训练数据
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# 从开放数据集中下载测试数据
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 60000
Number of testing examples: 10000

数据加载器(可选)

batch_size = 64

# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

创建模型

# 选择训练设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device

class LeNet(nn.Module):
    def __init__(self, output_dim):
        super().__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv_2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc_1 = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.Sigmoid()
        )

        self.fc_2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.Sigmoid()
        )

        self.fc_3 = nn.Linear(84, output_dim)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = x.view(x.size(0), -1)
        x = self.fc_1(x)
        x = self.fc_2(x)
        x = self.fc_3(x)
        return x

model = LeNet(10).to(device)
print(model)

LeNet(
(conv_1): Sequential(
(0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv_2): Sequential(
(0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc_1): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): Sigmoid()
)
(fc_2): Sequential(
(0): Linear(in_features=120, out_features=84, bias=True)
(1): Sigmoid()
)
(fc_3): Linear(in_features=84, out_features=10, bias=True)
)

训练模型

选择损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

训练循环

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

测试循环

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练模型

epochs = 10.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 10
loss: 0.015569 [ 64/60000]
loss: 0.029817 [ 6464/60000]
loss: 0.043169 [12864/60000]
loss: 0.027709 [19264/60000]
loss: 0.021492 [25664/60000]
loss: 0.011533 [32064/60000]
loss: 0.045418 [38464/60000]
loss: 0.042875 [44864/60000]
loss: 0.152001 [51264/60000]
loss: 0.040214 [57664/60000]
Test Error:
Accuracy: 98.6%, Avg loss: 0.044844

模型处理

保存模型

model_name = 'LeNet'
model_file = model_name + ".pth"
torch.save(model.state_dict(), model_file)
print("Saved PyTorch Model State to " + model_file)

Saved PyTorch Model State to LeNet.pth

Summary


安装torchsummary

pip install torchsummary

调用summary

from torchsummary import summary

model = LeNet(10).to(device)
summary(model, (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 6, 28, 28]             156
           Sigmoid-2            [-1, 6, 28, 28]               0
         MaxPool2d-3            [-1, 6, 14, 14]               0
            Conv2d-4           [-1, 16, 10, 10]           2,416
           Sigmoid-5           [-1, 16, 10, 10]               0
         MaxPool2d-6             [-1, 16, 5, 5]               0
            Linear-7                  [-1, 120]          48,120
           Sigmoid-8                  [-1, 120]               0
            Linear-9                   [-1, 84]          10,164
          Sigmoid-10                   [-1, 84]               0
           Linear-11                   [-1, 10]             850
================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
----------------------------------------------------------------

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1480018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# 获取类型 Type.GetType()

背景 C#是强类型语言,任何对象都有Type,有时候需要使用Type来进行反射、序列化、筛选等,获取Type有Type.GetType, typeof(),object.GetType() 等方法,本文重点介绍Type.GetType()。 系统类型/本程序集内的类型 对于系…

C++——模板详解

目录 模板 函数模板 显示实例化 类模板 模板特点 模板 模板,就是把一个本来只能对特定类型实现的代码,变成一个模板类型,这个模板类型能转换为任何内置类型,从而让程序员只需要实现一个模板,就能对不同的数据进行操…

2024年 前端JavaScript Web APIs 第二天 笔记

Web APIs 第二天 2.1 -事件监听以及案例 2.2 -随机点名案例 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><t…

Git分布式版本控制系统——git学习准备工作

一、Git仓库介绍 开发者可以通过Git仓库来存储和管理文件代码&#xff0c;Git仓库分为两种&#xff1a; 本地仓库&#xff1a;开发人员自己电脑上的Git仓库 远程仓库&#xff1a;远程服务器上的Git仓库 仓库之间的运转如下图&#xff1a; commit&#xff1a;提交&#xff…

Pytorch 复习总结 4

Pytorch 复习总结&#xff0c;仅供笔者使用&#xff0c;参考教材&#xff1a; 《动手学深度学习》Stanford University: Practical Machine Learning 本文主要内容为&#xff1a;Pytorch 深度学习计算。 本文先介绍了深度学习中自定义层和块的方法&#xff0c;然后介绍了一些…

Doccano 修复 spacy.gold 的bug

引言 最初只是想把Doccano标注的数据集转换成BIO(类似conll2003数据集)的标注格式&#xff1b; 按照PR的修改意见实现了修改&#xff0c;但是本人不建议这么做&#xff1b; 应该随着Doccano的升级&#xff0c;Doccano的导出格式发生了变化&#xff0c;而原来的doccano-transfo…

正确认识肠道内脆弱拟杆菌——其在健康的阴暗面和光明面

谷禾健康 脆弱拟杆菌(Bacteroides fragilis)是拟杆菌门拟杆菌属的重要成员。事实上&#xff0c;脆弱拟杆菌因其免疫调节功能而成为该属中研究最多的共生微生物。它是革兰氏阴性、不形成孢子、杆状专性厌氧菌。在人类健康中扮演着复杂而双面的角色。 这种革兰氏阴性专性厌氧菌常…

架构设计方法(4A架构)-信息架构

1、 信息架构&#xff08;IA&#xff09;&#xff1a;现实事物在IT世界的建模体现 2、数据资产目录 3、 识别业务对象&#xff1a;业务对象的设计方法 设计方法 1.基于业务流程识别业务活动。 2. 识别业务流程中每个业务活动的输入、输出等BI&#xff08;Business Item&#…

Zabbix企业运维监控工具

Zabbix企业级监控方案 常见监控软件介绍 Cacti Cacti是一套基于 PHP、MySQL、SNMP 及 RRD Tool 开发的监测图形分析工具&#xff0c;Cacti 是使用轮询的方式由主服务器向设备发送数据请求来获取设备上状态数据信息的,如果设备不断增多,这个轮询的过程就非常的耗时&#xff0…

SpringBoot源码解读与原理分析(三十七)SpringBoot整合WebMvc(二)DispatcherServlet的工作全流程

文章目录 前言12.4 DispatcherServlet的工作全流程12.4.1 DispatcherServlet#service12.4.2 processRequest12.4.3 doService12.4.3.1 isIncludeRequest的判断12.4.3.2 FlashMapManager的设计 12.4.4 doDispatch12.4.4.1 处理文件上传请求12.4.4.2 获取可用的Handler&#xff0…

Vue+SpringBoot打造农村物流配送系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统登录、注册界面2.2 系统功能2.2.1 快递信息管理&#xff1a;2.2.2 位置信息管理&#xff1a;2.2.3 配送人员分配&#xff1a;2.2.4 路线规划&#xff1a;2.2.5 个人中心&#xff1a;2.2.6 退换快递处理&#xff1a;…

5G时代对于工业化场景应用有什么改善

5G 不仅仅是 4G 的技术升级&#xff0c;而是将平板电脑和智能手机的技术升级。除了更好的高清视频流和其他高带宽应用&#xff0c;消费者不会注意到很多性能差异。然而&#xff0c;在工业领域&#xff0c;5G 代表着巨大的飞跃。 在工厂和厂房内&#xff0c; 设备的Wi-Fi 网络经…

低功耗运放D722,具有9MHz的高增益带宽积,转换速率为8.5V/μs

D722是低噪声、低电压、低功耗运放&#xff0c;应用广泛。D722具有9MHz的高增益带宽积&#xff0c;转换速率为8.5V/μs&#xff0c;静态电流为1.7mA&#xff08;5V电源电压&#xff09;。D722具有低电压、低噪声的特点&#xff0c;并提供轨到轨输出能力&#xff0c;D722的最大输…

本地maven库缓存导入私库

为了加速编译代码&#xff0c;想将本地maven缓存导入内网私库使用。 脚本网上搜的 #!/bin/bash # copy and run this script to the root of the repository directory containing files # this script attempts to exclude uploading itself explicitly so the script name …

C++指针(二)

个人主页&#xff1a;PingdiGuo_guo 收录专栏&#xff1a;C干货专栏 文章目录 1.数组指针 1.1数组指针的概念 1.2数组指针的用处 1.3数组指针的操作 1.4二维数组如何访问 1.5数组指针访问流程 1.6数组指针的练习题 2.指针数组 2.1指针数组的概念 2.2指针数组的用处 2…

AMEYA360:航顺车规级MCU HK32AUTO39A的汽车侧滑门控制方案

汽车滑门因侧开启方式与传统车门相比&#xff0c;具有易泊车、开启宽度大和方便乘员货物进出的优点&#xff0c;很受消费者的青睐。汽车市场上&#xff0c;无论是面向高端的商务豪华MPV&#xff0c;还是面向城市物流的轻型客车和低端客运微型车都采用了汽车机械滑门系统。 汽车…

韦东山嵌入式Liunx入门驱动开发三

文章目录 一、GPIO和Pinctrl子系统的使用1-1 Pinctrl子系统1-2 GPIO子系统1-3 基于GPIO子系统的LED驱动程序 本人学习完韦老师的视频&#xff0c;因此来复习巩固&#xff0c;写以笔记记之。 韦老师的课比较难&#xff0c;第一遍不知道在说什么&#xff0c;但是坚持看完一遍&…

死记硬背spring bean 的生命周期

1.bean的生命周期 我们平常经常使用类似于new Object()的方式去创建对象&#xff0c;在这个对象没有任何引用的时候&#xff0c;会被gc给回收掉。而对于spring而言&#xff0c;它本身存在一个Ioc容器&#xff0c;就是用来管理对象的&#xff0c;而对象的生命周期也完全由这个容…

为什么软考报名人数越来越多?

2020年软考报名人数404666人&#xff0c;广东省报考人数超过14万人。 ●2021年软考通信考试报名人数突破100万人&#xff0c;估计软考有90多万。 ●2022年软考通信考试共129万人&#xff0c;估计软考占了120多万人。 ●2023年软考具体报名人数没有公布&#xff0c;但工业和信…

three 层级模型

group.remove(mesh1,mesh2);Vector3与模型位置、缩放属性 Group层级模型(树结构) 创建了两个网格模型mesh1、mesh2&#xff0c;通过THREE.Group类创建一个组对象group,然后通过add方法把网格模型mesh1、mesh2作为设置为组对象group的子对象&#xff0c;然后在通过执行scene.a…