人工智能算法工程师(中级)课程9-PyTorch神经网络之全连接神经网络实战与代码详解

news2024/9/20 22:30:20

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程9-PyTorch神经网络之全连接神经网络实战与代码详解。本文将给大家展示全连接神经网络与代码详解,包括全连接模型的设计、数学原理介绍,并从手写数字识别到猫狗识别实战演练。

文章目录

  • 一、引言
  • 二、全连接模型的设计
    • 1. 神经元模型
    • 2. 网络结构
  • 三、全连接模型的参数计算
    • 1. 前向传播
    • 2. 反向传播
  • 四、全连接模型实现手写数字识别
    • 1. 数据准备
    • 2. 模型构建
    • 3. 代码实现
  • 五、阶段实战:猫狗识别
    • 1. 数据准备
    • 2. 模型构建
    • 3. 代码实现
  • 六、数学原理详解
    • 1. 激活函数
    • 2. 损失函数
    • 3. 优化算法
  • 七、总结

一、引言

全连接神经网络(Fully Connected Neural Network,FCNN)是一种经典的神经网络结构,它在众多领域都有着广泛的应用。本文将详细介绍全连接神经网络的设计、参数计算及其在图像识别任务中的应用。通过本文的学习,读者将掌握全连接神经网络的基本原理,并能够实现手写数字识别和猫狗识别等实战项目。

二、全连接模型的设计

1. 神经元模型

全连接神经网络的基本单元是神经元,其数学表达式为:
f ( x ) = σ ( ∑ i = 1 n w i x i + b ) f(x) = \sigma(\sum_{i=1}^{n}w_ix_i + b) f(x)=σ(i=1nwixi+b)
其中, x x x 为输入向量, w w w 为权重向量, b b b 为偏置, σ \sigma σ 为激活函数。

2. 网络结构

全连接神经网络由输入层、隐藏层和输出层组成。每一层的神经元都与上一层的所有神经元相连,如图1所示。
在这里插入图片描述

三、全连接模型的参数计算

1. 前向传播

假设一个全连接神经网络共有 l l l层,第 k k k层的输入为 X ( k ) X^{(k)} X(k),输出为 Y ( k ) Y^{(k)} Y(k),则有:
Y ( k ) = σ ( W ( k ) X ( k ) + b ( k ) ) Y^{(k)} = \sigma(W^{(k)}X^{(k)} + b^{(k)}) Y(k)=σ(W(k)X(k)+b(k))
其中, W ( k ) W^{(k)} W(k) b ( k ) b^{(k)} b(k) 分别为第 k k k层的权重和偏置。

2. 反向传播

全连接神经网络的参数更新通过反向传播算法实现。对于输出层,损失函数为:
L = 1 2 ( Y t r u e − Y p r e d ) 2 L = \frac{1}{2}(Y_{true} - Y_{pred})^2 L=21(YtrueYpred)2
其中, Y t r u e Y_{true} Ytrue 为真实标签, Y p r e d Y_{pred} Ypred 为预测值。
根据链式法则,输出层的权重梯度为:
∂ L ∂ W ( l ) = ∂ L ∂ Y ( l ) ⋅ ∂ Y ( l ) ∂ Z ( l ) ⋅ ∂ Z ( l ) ∂ W ( l ) \frac{\partial L}{\partial W^{(l)}} = \frac{\partial L}{\partial Y^{(l)}} \cdot \frac{\partial Y^{(l)}}{\partial Z^{(l)}} \cdot \frac{\partial Z^{(l)}}{\partial W^{(l)}} W(l)L=Y(l)LZ(l)Y(l)W(l)Z(l)
其中, Z ( l ) = W ( l ) X ( l ) + b ( l ) Z^{(l)} = W^{(l)}X^{(l)} + b^{(l)} Z(l)=W(l)X(l)+b(l)
同理,可求得输出层的偏置梯度、隐藏层的权重梯度和偏置梯度。

四、全连接模型实现手写数字识别

1. 数据准备

使用MNIST数据集,包含60000个训练样本和10000个测试样本。

2. 模型构建

构建一个简单的全连接神经网络,包含一个输入层(784个神经元)、两个隐藏层(128个神经元)和一个输出层(10个神经元)。
在这里插入图片描述

3. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.model(x)

# 加载数据
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# 初始化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 训练模型
for epoch in range(5):
    for i, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 评估模型
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

五、阶段实战:猫狗识别

1. 数据准备

使用猫狗数据集,包含25000张猫和狗的图片。我们将猫和狗的照片放在目录’data/train’下。

2. 模型构建

构建一个全连接神经网络,包含一个输入层(64643个神经元)、三个隐藏层(256、128、64个神经元)和一个输出层(2个神经元)。

3. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据预处理
data_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomRotation(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(0, translate=(0.2, 0.2), scale=(0.8, 1.2)),
    transforms.ToTensor(),
])

# 加载数据
train_dataset = datasets.ImageFolder('data/train', transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*64*3, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 初始化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

# 训练模型
for epoch in range(15):
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.float().unsqueeze(1).to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 评估模型
# 假设有一个测试数据集的加载器叫做 validation_loader
correct = 0
total = 0
with torch.no_grad():
    for images, labels in validation_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

六、数学原理详解

1. 激活函数

激活函数用于引入非线性因素,使得神经网络能够学习和模拟复杂函数。常用的激活函数有:

  • Sigmoid函数: σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1
  • ReLU函数: R e L U ( x ) = max ⁡ ( 0 , x ) ReLU(x) = \max(0, x) ReLU(x)=max(0,x)
  • Softmax函数: s o f t m a x ( x ) i = e x i ∑ j e x j softmax(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(x)i=jexjexi

2. 损失函数

损失函数用于衡量模型预测值与真实值之间的差异。常用的损失函数有:

  • 均方误差(MSE): M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE(y, \hat{y}) = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2 MSE(y,y^)=n1i=1n(yiy^i)2
  • 交叉熵损失:对于二分类问题, C E ( y , y ^ ) = − y log ⁡ ( y ^ ) − ( 1 − y ) log ⁡ ( 1 − y ^ ) CE(y, \hat{y}) = -y\log(\hat{y}) - (1-y)\log(1-\hat{y}) CE(y,y^)=ylog(y^)(1y)log(1y^)

3. 优化算法

优化算法用于更新网络的权重和偏置,以最小化损失函数。常用的优化算法有:

  • 梯度下降(Gradient Descent): w : = w − α ∂ L ∂ w w := w - \alpha \frac{\partial L}{\partial w} w:=wαwL
  • Adam优化器:结合了动量(Momentum)和自适应学习率(Adagrad)的优点。

七、总结

本篇文章从全连接神经网络的基本原理出发,介绍了全连接模型的设计、参数计算以及如何实现手写数字识别和猫狗识别。通过配套的完整可运行代码,读者可以更好地理解全连接神经网络的实现过程。在实际应用中,全连接神经网络虽然已被卷积神经网络(CNN)等更先进的网络结构所取代,但其基本原理仍然是深度学习领域的重要基石。希望本文能帮助读者深入掌握全连接神经网络,并为后续学习打下坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925802.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM监控及诊断工具-命令行篇-jstat命令介绍

JVM监控及诊断工具-命令行篇01-jstat:查看JVM统计信息 一 基本情况二 基本语法2.1 option参数1. 类装载相关的:2. 垃圾回收相关的-gc:显示与GC相关的堆信息。包括Eden区、两个Survivor区、老年代、永久代等的容量、已用空间、GC时间合计等信息…

基于51单片机的多路报警器Protues仿真设计

一、设计背景 随着社会的发展和技术的进步,安全问题越来越受到重视。各种工业设施、家庭、商业场所以及公共场所的安全保障成为了重点。报警器作为安全防护系统的重要组成部分,在预防和及时应对各种突发事件中发挥着至关重要的作用。传统的报警器通常在功…

【C++】哈希(散列)表

目录 一、哈希表的基本概念1.哈希的概念2.哈希冲突2.1 哈希函数2.2 哈希冲突的解决办法2.2.1 闭散列2.2.2 开散列 二、哈希表的实现1.闭散列的实现1.1 闭散列的结构1.2 闭散列的插入1.3 闭散列的删除1.4 闭散列的查找 2.开散列的实现2.1 key值不能取模的情况2.2 开散列的结构2.…

Redis的安装配置及IDEA中使用

目录 一、安装redis,配置redis.conf 1.安装gcc 2.将redis的压缩包放到指定位置解压 [如下面放在 /opt 目录下] 3.编译安装 4.配置redis.conf文件 5.开机自启 二、解决虚拟机本地可以连接redis但是主机不能连接redis 1.虚拟机网络适配器网络连接设置为桥接模式…

STM32智能空气质量监测系统教程

目录 引言环境准备智能空气质量监测系统基础代码实现:实现智能空气质量监测系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景:空气质量监测与优化问题解决方案与优化收尾与总结 1. 引言 智能空…

【练习】分治--归并排序

🎥 个人主页:Dikz12🔥个人专栏:算法(Java)📕格言:吾愚多不敏,而愿加学欢迎大家👍点赞✍评论⭐收藏 目录 归并排序 代码实现 交易逆序对的总数 题目描述 ​编辑 题解 代码实…

基于Vue和UCharts的前端组件化开发:实现高效、可维护的词云图与进度条组件

基于Vue和UCharts的前端组件化开发:实现高效、可维护的词云图与进度条组件 摘要 随着前端技术的迅速发展和业务场景的日益复杂,传统的整块应用开发方式已无法满足现代开发的需求。组件化开发作为一种有效的解决方案,能够将系统拆分为独立、…

汽车的驱动力,是驱动汽车行驶的力吗?

一、地面对驱动轮的反作用力? 汽车发动机产生的转矩,经传动系传至驱动轮上。此时作用于驱动轮上的转矩Tt产生一个对地面的圆周力F0,地面对驱动轮的反作用力Ft(方向与F0相反)即是驱动汽车的外力,此外力称为汽车的驱动力。 即汽车…

C++中跨平台类的设计方法

目录 1.引言 2.具体实现 2.1.单一继承实现 2.2.桥接方式实现 3.总结 1.引言 进行C代码的跨平台设计,主要目标是确保编写的代码能够在不同的操作系统(如Windows、Linux、macOS等)和硬件架构(如x86、ARM等)上无缝运…

leetcode--二叉树中的最大路径和

leetcode地址:二叉树中的最大路径和 二叉树中的 路径 被定义为一条节点序列,序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一定经过根节点。 路径和 是路径中各节点值的总…

编译x-Wrt 全过程

参考自;​​​​​​c编译教程 | All about X-Wrt 需要详细了解的小伙伴还请参看原文 ^-^ 概念: x-wrt(基于openwrt深度定制的发行版本) 编译系统: ubuntu22.04 注意: 特别注意的是,整个编译过程,都是用 …

C到C嘎嘎的衔接篇

本篇文章,是帮助大家从C向C嘎嘎的过渡,那么我们直接开始吧 不知道大家是否有这样一个问题,学完C的时候感觉还能听懂,但是听C嘎嘎感觉就有点难度或者说很难听懂,那么本篇文章就是帮助大家从C过渡到C嘎嘎。 C嘎嘎与C的区…

代码随想录算法训练营第三十四天|322. 零钱兑换、279.完全平方数

322. 零钱兑换 给定不同面额的硬币 coins 和一个总金额 amount。编写一个函数来计算可以凑成总金额所需的最少的硬币个数。如果没有任何一种硬币组合能组成总金额,返回 -1。 dp[j]凑足总额为j的方法有dp[j]个。 递推公式:凑足金额为j-coins[i]的方法有d…

linux下安装cutecom串口助手;centos安装cutecom串口助手;rpm安装包安装cutecom串口助手

在支持apt-get的系统下安装 在终端命令行中输入: sudo apt-get install cutecom 安装好后输入 sudo cutecom 就可以了 关于如何使用,可以看这个https://www.cnblogs.com/xingboy/p/14388610.html 如果你的电脑不支持apt-get。 那我们就通过安装包…

【Android】kotlin jdk版本冲突与Kotlin依赖管理插件

1、androidx.activity:activity:1.8.0 依赖版本错误问题 *依赖项“androidx.activity:activity:1.8.0”要求依赖它的库和应用针对版本 34 或更高版本 Android API 进行编译。:app 目前是针对 android-33 编译的。此外…

QTabWidget、QListWidget、QStackedWidget

The QTabWidget class provides a stack of tabbed widgets. More... The QListWidget class provides an item-based list widget. More... QStringList strlist;strlist<<"系统"<<"外观"<<"截图"<<"贴图"…

FPGA设计之跨时钟域(CDC)设计篇(1)----亚稳态到底是什么?

1、什么是亚稳态? 在数字电路中,如果数据传输时不满足触发器FF的建立时间要求Tsu和保持时间要求Th,就可能产生亚稳态(Metastability),此时触发器的输出端(Q端)在有效时钟沿之后比较长的一段时间都会处于不确定的状态(在0和1之间振荡),而不是等于数据输入端(D端)的…

Java NIO 总结: NIO技术基础回顾

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

k8s集群新增节点

目前集群状态 如K8S 集群搭建中规划的集群一样 Masternode01node02IP192.168.100.100192.168.100.101192.168.100.102OSCent OS 7.9Cent OS 7.9Cent OS 7.9 目前打算新增节点node03 Masternode01node02node03IP192.168.100.100192.168.100.101192.168.100.102192.168.100.1…

Java核心篇之JVM探秘:对象创建与内存分配机制

系列文章目录 第一章 Java核心篇之JVM探秘&#xff1a;内存模型与管理初探 第二章 Java核心篇之JVM探秘&#xff1a;对象创建与内存分配机制 第三章 Java核心篇之JVM探秘&#xff1a;垃圾回收算法与垃圾收集器 第四章 Java核心篇之JVM调优实战&#xff1a;Arthas工具使用及…