AI-基本概念-多层感知器模型/CNN/RNN/自注意力模型

news2024/11/6 9:42:09

1 需求

神经网络

……


深度学习

……


深度学习包含哪些神经网络:

  • 全连接神经网络
  • 卷积神经网络
  • 循环神经网络
  • 基于注意力机制的神经网络


2 接口


3 CNN

在这个示例中:

 
  • 首先定义了一个简单的卷积神经网络SimpleCNN,它包含两个卷积层、两个池化层和两个全连接层。
  • 然后通过torchvision库加载了 MNIST 数据集,并进行了数据预处理。
  • 接着使用交叉熵损失函数和随机梯度下降优化器对模型进行了 10 个周期的训练。
  • 最后在测试集上对模型进行了测试,计算了模型的准确率。这是一个基础的 PyTorch CNN 应用示例,你可以根据实际需求修改模型结构、数据和训练参数等。

第一步,定义卷积神经网络(CNN)模型

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 第一个卷积层,输入通道为1(灰度图像),输出通道为32,卷积核大小为3x3
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        # 第一个卷积层后的激活函数ReLU
        self.relu1 = nn.ReLU()
        # 第一个最大池化层,池化核大小为2x2
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        # 第二个卷积层,输入通道为32,输出通道为64,卷积核大小为3x3
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # 全连接层,将卷积层输出的特征图展平后连接到该层,输入大小为64 * 6 * 6,输出大小为128
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.relu3 = nn.ReLU()
        # 最后一个全连接层,用于分类,输出大小为10(假设是10分类问题)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # 将特征图展平
        x = x.view(-1, 64 * 6 * 6)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x

第二步,准备数据(以 MNIST 数据集为例)

import torchvision
import torchvision.transforms as transforms

# 定义数据转换,将图像转换为张量并进行归一化
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# 下载并加载训练数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
# 下载并加载测试数据集
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

第三步,训练模型

# 创建模型实例
model = SimpleCNN()
# 定义损失函数(交叉熵损失)和优化器(随机梯度下降)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练循环
for epoch in range(10):  # 进行10个训练周期
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据和标签
        inputs, labels = data
        # 梯度清零
        optimizer.zero_grad()
        # 前向传播
        outputs = model(inputs)
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 累计损失
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

第四步,测试模型

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # 模型预测
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the test set: {100 * correct / total}%')

4 参考资料

神经网络——最易懂最清晰的一篇文章-CSDN博客

  1. 多层感知机(Multilayer Perceptron,MLP)
    • 结构特点:是一种简单的前馈神经网络,由输入层、一个或多个隐藏层和输出层组成。神经元之间全连接,即每个神经元与相邻层的所有神经元都有连接。例如,在一个用于手写数字识别的简单 MLP 中,输入层接收图像像素值,经过隐藏层的非线性变换后,输出层输出各个数字类别对应的概率。
    • 应用场景:广泛应用于分类和回归问题,如简单的图像分类、数据预测等。在自然语言处理领域可用于文本分类,在金融领域用于股票价格预测等。
  2. 卷积神经网络(Convolutional Neural Network,CNN)
    • 结构特点:主要由卷积层、池化层和全连接层组成。卷积层通过卷积核提取数据的局部特征,池化层进行下采样以减少数据维度和计算量,全连接层用于分类或回归等任务。例如在人脸识别任务中,卷积层可以提取人脸五官轮廓等特征。
    • 应用场景:在计算机视觉领域占据主导地位,用于图像分类(如识别图片中的物体是猫还是狗)、目标检测(检测图像中物体的位置和类别)、语义分割(将图像中的每个像素分类到不同语义类别)等。也在音频处理等领域有应用,如语音识别中的声学模型。
  3. 循环神经网络(Recurrent Neural Network,RNN)
    • 结构特点:具有循环连接,能够处理序列数据。在每个时间步,神经元接收当前输入和上一个时间步的隐藏状态,经过处理后输出当前时间步的隐藏状态和预测结果。例如在机器翻译中,RNN 可以逐词处理输入句子和生成翻译后的句子。
    • 应用场景:自然语言处理领域的文本生成、机器翻译、情感分析等任务,以及时间序列预测,如股票走势预测、气象数据预测等。不过,传统 RNN 存在梯度消失和梯度爆炸问题。
  4. 长短期记忆网络(Long - Short Term Memory,LSTM)和门控循环单元(Gated Recurrent Unit,GRU)
    • 结构特点(以 LSTM 为例):是 RNN 的变体,通过特殊的门控机制(输入门、遗忘门和输出门)来控制信息的流动,能够有效解决 RNN 中的梯度消失和梯度爆炸问题,更好地处理长序列数据。例如在长篇小说生成任务中,LSTM 可以有效地利用前文信息生成后续内容。GRU 结构相对更简单,将遗忘门和输入门合并为一个更新门,在性能上和 LSTM 类似,并且计算效率更高。
    • 应用场景:和 RNN 类似,主要用于自然语言处理中的长文本处理、语音识别中的语音序列处理、时间序列分析等需要处理长序列数据的任务。
  5. 生成对抗网络(Generative Adversarial Network,GAN)
    • 结构特点:由生成器和判别器两个神经网络组成。生成器的任务是生成尽可能逼真的数据,判别器的任务是区分真实数据和生成器生成的数据。两者通过对抗训练的方式不断提高性能,最终生成器能够生成高质量的假数据。例如在图像生成任务中,生成器可以根据噪声生成看起来像真实照片的图像。
    • 应用场景:图像生成(如生成高分辨率的风景照片)、数据增强(为训练数据集生成新的样本)、风格迁移(将一种图像风格转换为另一种风格)等。
  6. 自编码器(Auto - Encoder)
    • 结构特点:由编码器和解码器组成。编码器将输入数据压缩成低维的表示(编码),解码器将这个编码还原为尽可能接近原始输入的数据。例如,在图像压缩任务中,编码器将高分辨率图像转换为低维向量,解码器再将这个向量还原为图像。
    • 应用场景:数据降维、图像去噪、特征提取等。例如,在医学影像处理中,可以利用自编码器提取有价值的特征用于疾病诊断。
  7. Transformer 架构
    • 结构特点:基于自注意力机制(Self - Attention),摒弃了传统的循环结构,能够并行计算,大大提高了训练和推理速度。在处理序列数据时,通过计算每个位置与其他位置的相关性来提取特征。例如在自然语言处理中的 BERT 模型,就是基于 Transformer 架构,能够有效捕捉句子中单词之间的语义关系。
    • 应用场景:自然语言处理领域的预训练语言模型(如 GPT 系列、BERT 系列)、机器翻译等任务。在计算机视觉领域也有基于 Transformer 的模型用于图像分类等任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2230288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

链表详解(一)

目录 顺序表的问题及思考链表链表的概念及结构链表的分类单链表的实现链表功能实现遍历链表void SLTprint(SLNode* phead)代码 创造新节点SLNode* CreateNode(SLNDataType x)代码 顺序表的问题及思考 中间/头部的插入删除,时间复杂度为O(N),效率低,但是尾部插入效率…

【MongoDB】Windows/Docker 下载安装,MongoDB Compass的基本使用、NoSQL、MongoDB的基础概念及基础用法(超详细)

文章目录 Windows下载MongoDB Compass使用NoSQL的基本概念MongoDB常用术语MongoDB与RDBMS区别MongoDB的CRUD 更多相关内容可查看 Docker安装MongoDB可查看:Docker-安装MongoDB Windows下载 官网下载地址:https://www.mongodb.com/try/download/communi…

uni-app发起请求以及请求封装,上传及下载功能(六)

文章目录 一、发起网络请求1.使用及封装2. https 请求配置自签名证书3.拦截器 二、上传下载1.上传 uni.uploadFile(OBJECT)2. 下载 uni.downloadFile(OBJECT) 一、发起网络请求 uni-app中内置的uni.request()已经很强大了,简单且好用。为了让其更好用,同…

C语言 | Leetcode C语言题解之第526题优美的排列

题目&#xff1a; 题解&#xff1a; int countArrangement(int n) {int f[1 << n];memset(f, 0, sizeof(f));f[0] 1;for (int mask 1; mask < (1 << n); mask) {int num __builtin_popcount(mask);for (int i 0; i < n; i) {if (mask & (1 <<…

《Baichuan-Omni》论文精读:第1个7B全模态模型 | 能够同时处理文本、图像、视频和音频输入

技术报告Baichuan-Omni Technical ReportGitHub仓库地址 文章目录 论文摘要1. 引言简介2. 训练2.1. 高质量的多模态数据2.2. 多模态对齐预训练2.2.1. 图像-语言分支2.2.2. 视频语音分支2.2.3. 音频语言分支2.2.4. 图像-视频-音频全方位对齐 2.3. 多模态微调监督 3. 实验3.1. 语…

Cesium着色器

GLSL&#xff08;OpenGL Shading Language&#xff09;是用于编写着色器的语言 着色器类型 顶点着色器&#xff1a;负责处理每个顶点的属性&#xff0c;如位置、颜色等。片段着色器&#xff08;或像素着色器&#xff09;&#xff1a;负责计算最终像素的颜色。 <!DOCTYPE h…

基于SSM医院门诊互联电子病历管理系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;医生管理&#xff0c;项目分类管理&#xff0c;项目信息管理&#xff0c;预约信息管理&#xff0c;检查信息管理&#xff0c;系统管理 用户账号功能包括&#xff1a;系统首页&…

论文阅读笔记-Get To The Point: Summarization with Pointer-Generator Networks

前言 最近看2021ACL的文章&#xff0c;碰到了Copying Mechanism和Coverage mechanism两种技巧&#xff0c;甚是感兴趣的翻阅了原文进行阅读&#xff0c;Copying Mechanism的模型CopyNet已经进行阅读并写了阅读笔记&#xff0c;如下&#xff1a; 论文阅读笔记&#xff1a;Copyi…

unocss 添加支持使用本地 svg 预设图标,并支持更改大小

安装 pnpm install iconify/utils 在配置文件 unocss.config.ts&#xff1a; presets > presetIcons 选项中 通过 FileSystemIconLoader 加载本地图标&#xff0c;并指定目录。 import presetWeapp from unocss-preset-weapp import { extractorAttributify, transformer…

NineData云原生智能数据管理平台新功能发布|2024年10月版

10 月发布内容 本月发布 7 项更新&#xff0c;其中重点发布 2 项、功能优化 3 项、性能优化 1 项、其他发布 1 项。 重点发布​ 数据库 Devops - 数据生成​ NineData 支持在数据库中自动生成符合特定业务场景的随机数据&#xff0c;用于模拟实际生产环境中的数据情况&…

10种数据预处理中的数据泄露模式解析:识别与避免策略

在机器学习教学实践中,我们常会遇到这样一个问题:"模型表现非常出色,准确率超过90%!但当将其提交到隐藏数据集进行测试时,效果却大打折扣。问题出在哪里?"这种情况几乎总是与数据泄露有关。 当测试数据在数据准备阶段无意中泄露(渗透)到训练数据时,就会发生数据泄露…

<十六>Ceph mon 运维

Ceph 集群有故障了&#xff0c;你执行的第一个运维命令是什么&#xff1f; 我猜测是ceph -s 。无论执行的第一个命令是什么&#xff0c;都肯定是先检查Mon。 在开始之前我们有必要介绍下Paxos协议&#xff0c;毕竟Mon就是靠它来实现数据唯一性。 一&#xff1a; Paxos 协议 1…

Spring Boot的核心优势及其应用详解

目录 前言1. Spring Boot的核心优势1.1 启动依赖的集成1.2 自动化配置 2. 内嵌服务器支持2.1 内嵌Tomcat服务器2.2 独立运行与便捷部署 3. 外部配置管理3.1 多环境支持3.2 配置优先级与外部化配置 4. Spring Boot的应用场景4.1 微服务架构4.2 云原生应用 结语 前言 在现代的Ja…

8进制在线编码工具--实现8进制编码

具体前往&#xff1a;文本转八进制在线工具-将文本字符串转换为8进制编码,支持逗号&#xff0c;空格和反斜杠分隔符

Windows 命令提示符(cmd)中输入 mysql 并收到错误消息“MySQL不是内部或外部命令,也不是可运行的程序或批处理文件?

目录 背景: 过程&#xff1a; 1.找到MySQL安装的路径 2.编辑环境变量 3.打开cmd&#xff0c;输入mysql --version测试成功 总结: 背景: 很早之前安装了Mysql数据库&#xff0c;想查询一下当前安装的MySQL客户端的版本号&#xff0c;我在命令行界面输入mysql --verion命令回…

Python学习的自我理解和想法(22)

学的是b站的课程&#xff08;千锋教育&#xff09;&#xff0c;跟老师写程序&#xff0c;不是自创的代码&#xff01; 今天是学Python的第22天&#xff0c;学的内容是正则表达式&#xff0c;明天会出一篇详细实例介绍。电脑刚修好&#xff01;开学了&#xff0c;时间不多&…

大数据-203 数据挖掘 机器学习理论 - 决策树 sklearn 剪枝参数 样本不均匀问题

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

报错:npm : 无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本。

报错场景 使用npm run dev 报错 npm : 无法加载文件 C:\Program Files\nodejs\npm.ps1&#xff0c;因为在此系统上禁止运行脚本。有关详细信息&#xff0c;请参阅 https:/go.microsoft.com/fwlink/?LinkID135170 中的 about_Execution_Policies。 所在位置 行:1 字符: 1 npm…

Python基于TensorFlow实现双向循环神经网络GRU加注意力机制分类模型(BiGRU-Attention分类算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后关注获取。 1.项目背景 随着深度学习技术的发展&#xff0c;循环神经网络&#xff08;RNN&#xff09;及其变种如门控循环…

从APP小游戏到Web漏洞的发现

一、前因&#xff1a; 在对一次公司的一个麻将游戏APP进行渗透测试的时候发现&#xff0c;抓到HTTP请求的接口&#xff0c;但是反编译APK后发现没有在本身发现任何一个关于接口或者域名相关的关键字&#xff0c;对此感到了好奇。 于是直接解压后everything搜索了一下&#xff…