神经网络学习笔记——如何设计、实现并训练一个标准的前馈神经网络

news2024/11/16 21:55:13

1.从零设计并训练一个神经网络icon-default.png?t=O83Ahttps://www.bilibili.com/video/BV134421U77t/?spm_id_from=333.337.search-card.all.click&vd_source=0b1f472915ac9cb9cdccb8658d6c2e69

一、如何设计、实现并训练一个标准的前馈神经网络,用于手写数字图像的分类,重点讲解了神经网络的设计和实现、数据的准备和处理、模型的训练和测试流程。

- 以数字图像作为输入,神经网络计算并识别图像中的数字。

- 输入层包含784个神经元,隐藏层用于特征提取,输出层包含10个神经元。

- 输出层的输出输入到soft max层,将十维的向量转换为十个概率值。

二、神经网络的设计思路和实现方法,以及手写数字识别的数据处理流程和代码实现,包括图像预处理、构建数据集等。

- 神经网络设计思路:每个概率值对应一个数字

- 手写数字识别训练数据:使用mini数据集

  • 数据处理流程:图像预处理、读取数据文件夹、构建数据集

三、使用PyTorch进行图像分类的步骤,包括读取数据、构建数据集、小批量数据读取、模型训练等,以及涉及到的对象和损失函数等。

  • 1、读取数据、构建数据集
  • 2、模型的训练
  • 使用train loader进行小批量数据读入,创建模型、优化器和损失函数进行训练
  • 训练模型的循环迭代,外层代表整个数据集的遍历次数,内层使用小批量数据读取进行梯度下降算法。
  • 3、模型的测试 
  • 注:测试的时候,需要​编辑model.eval() 
import torch
import torch.nn as nn

# 定义模型结构
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)
        self.dropout = nn.Dropout(0.5)
        self.batch_norm = nn.BatchNorm1d(2)

    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        x = self.batch_norm(x)
        return x

# 初始化模型
model = SimpleModel()

# 加载训练好的模型权重
model.load_state_dict(torch.load('model.pth'))

# 将模型设置为评估模式
model.eval()

# 测试数据
test_input = torch.randn(5, 10)

# 禁用梯度计算
with torch.no_grad():
    output = model(test_input)
    print(output)

Q1:为什么训练集要分批次训练,跟每条数据单独训练(batch_size=1)有什么不一样的吗?

  • 较大的 batch_size,梯度更新会更加平滑和稳定,模型能够更好地学到数据的总体分布特征。
  • 最优的batch size跟训练集的大小有关,大数据集适合大batch,小数据集适合小batch,极端情况下batch_size=1也不是不可以。

 Q2:为什么loss会不断变小?

  • 梯度下降只包含了局部的损失函数信息,所以只能保证存在趋近局部最优的可能。

Loss 在训练过程中不断变小是因为优化算法(如梯度下降)的作用,但这个现象背后有多个原因和理论支持。逐步解析:

1. 梯度下降原理

梯度下降算法的核心思想是利用目标函数(即损失函数)的梯度来迭代地更新模型的参数。梯度本身指示了损失函数增长最快的方向,因此,通过向梯度的反方向更新参数,可以逐步减小损失值。

2. 局部最优与全局最优

  • 局部最优:在多维空间中,损失函数可能存在多个局部最小值。梯度下降算法只能保证找到其中一个局部最小值,而不一定是全局最小值。
  • 全局最优:对于凸函数,任何局部最小值也是全局最小值。但对于非凸函数(大多数深度学习模型的损失函数),找到全局最小值更加复杂。

3. 损失函数的性质

  • 凸性:如果损失函数是凸的,那么任何局部最小值也是全局最小值,梯度下降法最终能够找到这个全局最小值。
  • 非凸性:对于非凸函数,虽然存在多个局部最小值,但梯度下降法依然可以找到某个局部最小值,使得损失函数值减小。

4. 学习率的作用

  • 学习率是梯度下降中一个关键的超参数,它决定了每一步参数更新的幅度。适当选择学习率可以保证算法的收敛性和稳定性。

5. 损失函数的优化目标

  • 训练过程中,优化的目标是最小化损失函数,这通常意味着模型的预测误差在减少
  • 随着训练的进行,模型逐渐学习到数据中的模式和结构,使得预测更加准确,从而损失值减小。

6. 泛化能力

  • 虽然训练过程中损失持续减小,但最终目标是提高模型在未知数据上的泛化能力
  • 为了防止过拟合,通常会采取正则化技术(如L1、L2正则化,Dropout等),以及早停(early stopping)策略。

7. 局部信息与全局搜索

  • 梯度下降利用的是局部信息(即当前位置的梯度),它提供了一种贪婪的搜索策略,每一步都朝着减少损失的方向前进。
  • 尽管只能保证趋近局部最优,但在实际应用中,通过合理的初始化、学习率调度和正则化策略,梯度下降往往能找到使损失足够小的参数配置。

结论

损失函数不断变小是因为梯度下降算法通过利用局部梯度信息来不断更新模型参数,使得模型逐渐学习到数据的内在规律,从而减少预测误差。虽然梯度下降只能保证找到局部最优解,但通过适当的策略和技巧,通常可以训练出性能良好的模型。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2133644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何制作Vector Vflash中加载的DLL文件--自动解锁刷写过程中27服务

案例背景: vFlash 是一种易于使用的工具,用于对一个或多个 ECU 进行刷写软件。由于方法灵活,它可以支持各种汽车原始设备制造商的不同刷写规范。它支持通过 CAN、CAN FD、FlexRay、LIN、以太网/DoIP 和以太网/SoAd 对 ECU 进行刷写。 vFlas…

SpringSecurity原理解析(六):SecurityConfigurer 解析

1、SecurityConfigurer SecurityConfigurer 在 Spring Security 中是一个非常重要的接口,观察HttpSecurity 中的很多 方法可以发现,SpringSecurity 中的每一个过滤器都是通过 xxxConfigurer 来进行配置的,而 这些 xxxConfigurer 其实都是 Sec…

针对Docker容器的可视化管理工具—DockerUI

目录 ⛳️推荐 前言 1. 安装部署DockerUI 2. 安装cpolar内网穿透 3. 配置DockerUI公网访问地址 4. 公网远程访问DockerUI 5. 固定DockerUI公网地址 ⛳️推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下…

GBI(生成式商业智能)实际业务生产落地运用上的探索和实践

前言 最近在探索如何发展AI在业务上的驱动力时了解到了生成式商业智能这一概念,同时本人也在探索ChatBI这一技术的实际落地运用,其实二者几乎在实现效果层面是一个意思,GBI(Generative Business Intelligence)是偏向业务方面,而C…

[000-01-008].第05节:OpenFeign高级特性-超时控制

我的后端学习大纲 SpringCloud学习大纲 1.1.OpenFeign超时的情况: 在Spring Cloud微服务架构中,大部分公司都是利用OpenFeign进行服务间的调用,而比较简单的业务使用默认配置是不会有多大问题的,但是如果是业务比较复杂&#xff…

UiBot教程:实现复杂流程图的高效方法

在自动化测试和RPA(机器人流程自动化)领域,使用UiBot绘制复杂流程图是日常工作中常见的挑战之一。如何在繁杂的逻辑中保持高效?如何实现复杂流程的自动化设计而不迷失于其中?这是许多测试工程师和自动化开发者所面临的…

区块链之变:揭秘Web3对互联网的改变

传统游戏中,玩家的虚拟资产(如角色、装备)通常由游戏公司控制,玩家无法真正拥有这些资产或进行交易。而在区块链游戏中,虚拟资产通过去中心化技术记录在区块链上,玩家对其拥有完全的所有权,并能…

Loki 分布式日志中心服务

目录 Loki 是什么 Loki 配置文件介绍 Loki 安装 Promtail 配置文件介绍 Promtail 安装 Loki 整合 Grafana Loki 是什么 Loki 是一个专为日志聚合和查询设计的开源分布式日志管理系统,由 Grafana Labs 开发。它与 Prometheus 类似,但用于处理日志&a…

决策树实战

文章目录 一、入门基础案例二、基于sklearn的决策树模型2.1sklearn中的决策树实现2.2分类型决策树:DecisionTreeClassifier2.2.1重要参数2.2.2重要属性与接口2.2.3基本案例:wine葡萄酒数据集 2.3回归型决策树:DecisionTreeRegressor2.3.1重要…

大学选修课无人机航拍技术与技巧怎么样?

在当今这个视觉盛行的时代,无人机航拍技术以其独特的视角和非凡的创意能力,正逐步成为影视制作、新闻报道、地理测绘、环境监测及个人记录生活等领域不可或缺的工具。为此,本大学特设《无人机航拍技术与技巧》选修课,旨在通过系统…

Linux数据相关-第3个服务-实时同步sersync

1、实时同步 背景: 之前我们通过rsync 定时任务实现定时备份/同步对于NFS我们需要进行实时同步 选择 分布式存储.。使用实时同步服务NFS。选择公有云对象存储,七牛存储,腾讯存储COS 选择:nfs实时同步工具 inotify(bug需要书…

3D点云目标检测数据集标注工具 保姆级教程——CVAT (附json转kitti代码)

前言: 笔者尝试过很多3D标注软件都遇到很多问题,例如CloudCompare不适合做3D目标检测的数据集而且分割地面的时很繁琐;labelCloud没有三视图,视角难以调整标得不够精确;SUSTechPOINTS换帧麻烦、输出时存储在docker里面…

每日OJ_牛客_数字统计(简单模拟)

目录 牛客_数字统计(简单模拟) 解析代码 牛客_数字统计(简单模拟) [NOIP2010]数字统计_牛客题霸_牛客网 描述 请统计某个给定范围[L, R]的所有整数中,数字2出现的次数。 比如给定范围[2, 22],数字2在数…

sipp模拟uas发送reinvite

概述 freeswitch是一款简单好用的VOIP开源软交换平台。 在更新了sipp模拟update的配置方案之后,我希望对比一下fs对update和reinvite的处理流程。 本文档记录sipp的配置方案,该方案中包含了update和reinvite的信令。 环境 CentOS 7.9 freeswitch 1…

Linux入门攻坚——32、Mini Linux制作

制作一个mini linux,需要对linux的启动流程很熟悉,这里又一次学习Linux的启动过程。 启动流程:CentOS 6 / 5: POST -->BootSequence(BIOS) --> BootLoader --> kernel (ramdisk) --> rootfs --> /sbin/init …

MySQL——数据类型(二)

目录 一、日期与时间类型 1.1 date 1.2 datetime 1.3 timestamp 二、枚举和联合 2.1 enum 2.2 set 2.2.1 set 的插入 2.2.2 set 的查找 思维导图可以参考如下链接: 数据类型.xmind 夜夜亮晶晶/MySQL - Gitee.com 一、日期与时间类型 1.1 date 日期 yyy…

2024 年最佳 Chrome 验证码扩展,解决 reCAPTCHA 问题

验证码,特别是 reCAPTCHA,已成为在线安全的不可或缺的一部分。虽然它们在区分人类和机器人方面起着至关重要的作用,但它们也可能成为合法用户和从事网络自动化的企业的主要障碍。无论您是试图简化在线体验的个人,还是依赖自动化工…

easy-es动态索引支持

背景 很多项目目前都引入了es,由于es弥补了mysql存储及搜索查询的局限性,随着技术的不断迭代,原生的es客户端使用比较繁琐不直观,上手代价有点大,所以easy-es框架就面世了,学习成本很低,有空大…

Ubuntu下安装最新版本Apache2文件服务器

文章目录 1.最新版本Apache2安装2. Apache2配置2.1 端口配置2.2 创建软连接,生成文件服务2.3 隐藏Apache2服务版本号2.4 添加用户,设置Apache2文件服务密码2.5 重启Apache2服务3. 执行后效果 1.最新版本Apache2安装 注意:安装最新版本必须升级Ubuntu为20…

Linux 中System V IPC的共享内存

1. 概念介绍 System V IPC(Inter-Process Communication)是一组在UNIX系统中用于进程间通信的机制,包括共享内存、消息队列和信号量。这些机制由System V内核提供,并且它们的存在不依赖于创建它们的进程,而是由内核管…