经典卷积神经网络 LeNet

news2025/1/15 23:43:12

一、实例图片

#我们传入的是28*28,所以加了padding
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

二、总结

1、LeNet是早期成功的神经网络

2、先使用卷积层来学习图片空间信息

3、然后使用全连接层来转换到类别空间

三、代码

1、评估模型,将参数放到GPU中

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        #eval是将模型设置为评估模式,评估模式就不会改变模型参数了可以用来预测结果;eval就是关闭模型中的dropout功能,调到评价模式;与之相对的是train()
        net.eval()
        if not device:
            #如果未提供设备参数,则使用模型第一个参数的设备作为默认设备。这确保了数据和模型在            同一设备上运行
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    #确保在计算精度时不会计算梯度,从而节省显存和提高计算效率
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

2、训练模型

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    #对于net里面的所有parameter,都去run一下那个初始化权重的函数。就是说在整个net中的所有层
     上面都使用init__weights函数来初始化所有现行层和卷积层的权重
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            #xavier能够根据输入输出的大小,使得初始化随机权重能使,输入和输出的方差差别不会
             很大,保证在模型最开始的时候,结果不会指数爆炸或者消失
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    #SGD是随机梯度下降算法
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                #l是当前批次的平均损失。X.shape[0]是当前批次的样本数。l * X.shape[0]
                 计算的是当前批次的总损失。
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]

            #这里是可视化
            #(i + 1) % (num_batches // 5)每训练到一个阶段时(5次中的每一次)会更新可视化数据
            #i == num_batches - 1:当训练到最后一个批次时,条件也会满足
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1895478.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android EditText的属性与用法

EditText 是编辑框控件,可以接收用户输入,并在程序中对用户输入进行处理。EditText在App里随处可见,在进行搜索、聊天、拨号等需要输入信息的场合,都可以使用 EditText。 图1 编辑框示意图 EditText 是TextView的子类&#xff0c…

吴恩达深度学习笔记:机器学习策略(2)(ML Strategy (2)) 2.5-2.6

目录 第三门课 结构化机器学习项目(Structuring Machine Learning Projects)第二周:机器学习策略(2)(ML Strategy (2))2.5 数据分布不匹配时的偏差与方差的分析(Bias and Variance with mismatched data di…

下载安装MySQL

1.软件的下载 打开官网下载mysql-installer-community-8.0.37.0.msi 2.软件的安装 mysql下载完成后,找到下载文件,双击安装 3.配置环境变量 4.自带客户端登录与退出

CSS filter(滤镜)属性,并实现页面置灰效果

目录 一、filter(滤镜)属性 二、准备工作 三、常用的filter属性值 1、blur(px) 2、brightness(%) 3、contrast(%) 4、grayscale(%) 5、opacity(%) 6、saturate(%) 7、sepia(%) 8、invert(%) 9、hue-rotate(deg) 10、drop-shadow(h-shadow v…

前端JS 插件实现下载【js-tool-big-box,下载大文件(fetch请求 + 下载功能版)

上一节,我们添加了下载大文件的纯功能版,意思就是需要开发者,在自己项目里发送请求,请求成功后,获取文件流的blob数据,然后 js-tool-big-box 帮助下载。 但考虑到,有些项目,可能比较…

装饰模式解析:基本概念和实例教程

目录 装饰模式装饰模式结构装饰模式应用场景装饰模式优缺点练手题目题目描述输入描述输出描述题解 装饰模式 装饰模式,又称装饰者模式、装饰器模式,是一种结构型设计模式,允许你通过将对象放入包含行为的特殊封装对象中来为原对象绑定新的行…

面试篇-Redis-1缓存三兄弟+数据一致性

文章目录 前言一、你们项目中使用Redis都做了什么:二、使用过程中遇到缓存穿透,缓存击穿,缓存雪崩你们如何处理:2.1 缓存穿透:2.1.1 通过缓存key值为null 进行处理:2.1.2 使用布隆过滤器:2.1.3 …

OpenCV基础(2)

目录 滤波处理 均值滤波 基本原理 函数用法 程序示例 高斯滤波 基本原理 函数用法 程序示例 中值滤波 基本原理 函数用法 程序示例 形态学 腐蚀 膨胀 通用形态学函数 前言:本部分是上一篇文章的延续,前面部分请查看:OpenCV…

深入理解如何撤销 Git 中不想提交的文件

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4f1…

图增强LLM + 可穿戴设备实时数据,生成个性化健康见解

图增强LLM 可穿戴设备实时数据,生成个性化健康见解 提出背景图增强LLM 子解法1(使用层次图模型) 子解法2(动态数据整合) 子解法3(LLM引导评估) 提出背景 论文:https://arxiv.or…

【js正则】去除文本中的a标签及其内容

场景&#xff1a;有时候服务端返回的文本中&#xff0c;包含a标签&#xff0c;前端不需要展示。 // 示例 const inputText 【提醒&#xff1a;XXXX】\nXXXXXX: 1\n\n<a href"https://export.shobserver.com/baijiahao/html/767805.html">详情</a>;JS正…

【营销策划模型大全】私域运营必备

营销策划模型大全&#xff1a;战略屋品牌屋、电商运营模型、营销战略、新媒体运营模型、品牌模型、私域运营模型…… 该文档是一份策划总监工作模型的汇总&#xff0c;包括战略屋/品牌屋模型、营销战略模型、品牌相关模型、电商运营模型、新媒体运营模型和私域运营模型等&…

JavaScript基础-函数(完整版)

文章目录 函数基本使用函数提升函数参数arguments对象&#xff08;了解&#xff09;剩余参数(重点)展开运算符(...) 逻辑中断函数参数-默认参数函数返回值-return作用域(scope)全局作用域局部作用域变量的访问原则垃圾回收机制闭包 匿名函数函数表达式立即执行函数 箭头函数箭头…

全自动内衣洗衣机什么牌子好?四大热门内衣洗衣机多角度测评

内衣洗衣机是近几年新兴的一种家用电器产品&#xff0c;正日益引起人们的重视。但是&#xff0c;面对市面上品牌繁多、款式繁多的内衣洗衣机&#xff0c;使得很多人都不知道该如何选择。身为一个数码家电博主&#xff0c;我知道这类产品在挑选方面有着比较深入的了解。为此&…

AIGC对设计师积极性的影响

随着科技的迅猛发展&#xff0c;生成式人工智能&#xff08;AIGC&#xff09;工具正逐渐深入设计的每个角落&#xff0c;对设计师的工作方式和思维模式产生了深远的影响。AIGC不仅极大提升了设计师的工作效率&#xff0c;更激发了他们的创新思维&#xff0c;为设计行业带来了翻…

好文阅读-日志篇

https://mp.weixin.qq.com/s/jABbG4MKvEiWXwdYwUk8SA 这里直接看最佳实践。 Maven 依赖 <dependencyManagement><dependencies><dependency><groupId>org.slf4j</groupId><artifactId>slf4j-api</artifactId><version>1.7.36…

聊聊 CTO 和 技术总监的区别

前言 CTO&#xff08;Chief Technology Officer&#xff09;&#xff0c;是首席技术官的意思。 技术总监&#xff0c;顾名思义&#xff0c;就是负责指导和监督公司的技术团队&#xff0c;确保技术和产品的开发与创新顺利进行。 有的软件公司同时有 CTO 和技术总监&#xff0…

第二届计算机、视觉与智能技术国际会议(ICCVIT 2024)

随着科技的飞速发展&#xff0c;计算机、视觉与智能技术已成为推动现代社会进步的重要力量。为了汇聚全球顶尖专家学者&#xff0c;共同探讨这一领域的最新研究成果和前沿技术&#xff0c;第二届计算机、视觉与智能技术国际会议&#xff08;ICCVIT 2024&#xff09;将于2024年1…

JAVA高级进阶11多线程

第十一天、多线程 线程安全问题 线程安全问题 多线程给我们带来了很大性能上的提升,但是也可能引发线程安全问题 线程安全问题指的是当个多线程同时操作同一个共享资源的时候,可能会出现的操作结果不符预期问题 线程同步方案 认识线程同步 线程同步 线程同步就是让多个线…

swiftui中几个常用的手势控制单击点击,双击和长按事件

简单做了一个示例代码&#xff0c;包含三个圆形形状&#xff0c;配置了不同的事件&#xff0c;示例代码&#xff1a; // // RouterView.swift // SwiftBook // // Created by song on 2024/7/4. //import SwiftUIstruct RouterView: View {State var isClick falsevar bod…