如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识

news2024/11/15 20:51:59
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 定义MLP模型
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # 创建一个顺序的层序列:包括一个扁平化层、两个全连接层和ReLU激活
        self.layers = nn.Sequential(
            nn.Flatten(),                       # 将28x28的图像扁平化为784维向量
            nn.Linear(28 * 28, 512),            # 第一个全连接层,784->512
            nn.ReLU(),                          # ReLU激活函数
            nn.Linear(512, 256),                # 第二个全连接层,512->256
            nn.ReLU(),                          # ReLU激活函数
            nn.Linear(256, 10)                  # 第三个全连接层,256->10 (输出10个类别)
        )
        
    def forward(self, x):
        return self.layers(x)                   # 定义前向传播

# 加载FashionMNIST数据集
# 定义图像的预处理:转换为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载FashionMNIST数据并应用转换
dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)

# 划分数据集为训练集和验证集
train_len = int(0.8 * len(dataset))           # 计算80%的长度作为训练数据
val_len = len(dataset) - train_len            # 剩下的20%作为验证数据
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练数据加载器,批量大小64,打乱数据
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)     # 验证数据加载器,批量大小64,不打乱

# 初始化模型、损失函数和优化器
model = MLP()                                 # 创建MLP模型实例
criterion = nn.CrossEntropyLoss()             # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器

# 训练模型
epochs = 5                                    # 定义训练5个epochs
for epoch in range(epochs):
    model.train()                             # 将模型设置为训练模式
    for inputs, labels in train_loader:       # 从训练加载器中获取批次数据
        outputs = model(inputs)               # 前向传播
        loss = criterion(outputs, labels)     # 计算损失
        optimizer.zero_grad()                 # 清除之前的梯度
        loss.backward()                       # 反向传播,计算梯度
        optimizer.step()                      # 更新权重
        
    # 在每个epoch结束时验证模型性能
    model.eval()                              # 将模型设置为评估模式
    total_correct = 0
    with torch.no_grad():                     # 不计算梯度,节省内存和计算量
        for inputs, labels in val_loader:     # 从验证加载器中获取批次数据
            outputs = model(inputs)           # 前向传播
            _, predicted = outputs.max(1)     # 获取预测的类别
            total_correct += (predicted == labels).sum().item()  # 统计正确的预测数量
    accuracy = total_correct / val_len        # 计算验证准确性
    print(f"Epoch {epoch + 1}/{epochs} - Validation accuracy: {accuracy:.4f}")  # 打印验证准确性

nn.Flatten() 是一个特殊的层,它将多维的输入数据“展平”为一维数据。这在处理图像数据时尤为常见,因为图像通常是多维的(例如,一个大小为28x28的灰度图像在PyTorch中会有一个形状为[28, 28]的张量)。

在神经网络的某些层,特别是全连接层(如nn.Linear)之前,通常需要对数据进行扁平化处理。因为全连接层期望其输入是一维的(或者更准确地说,它期望输入的最后一个维度对应于特征,其他维度对应于数据的批次)。

为了更具体,让我们看一个例子:

考虑一个大小为[batch_size, 28, 28]的张量,这可以看作是一个batch_size数量的28x28图像的批次。当我们传递这个批次的图像到一个nn.Linear(28*28, 512)层时,我们需要先将图像展平。也就是说,每个28x28的图像需要转换为长度为784的一维向量。因此,输入数据的形状会从[batch_size, 28, 28]变为[batch_size, 784]。

nn.Flatten()就是做这个转换的。在这个特定的例子中,它会将[batch_size, 28, 28]的形状转换为[batch_size, 784]。

总结一下:nn.Flatten()用于将多维输入数据转换为一维,从而使其可以作为全连接层(如nn.Linear)的输入。

  • transforms.Compose:
    这是一个简单的方式来链接(组合)多个图像转换操作。它会按照提供的顺序执行列表中的每个转换。

  • transforms.ToTensor():
    这个转换将PIL图像或NumPy的ndarray转换为FloatTensor。并且它将图像的像素值范围从0-255变为0-1。简言之,它为我们完成了数据类型和值范围的转换。

  • transforms.Normalize((0.5,), (0.5,)):
    这个转换标准化张量图像。给定的参数是均值和标准差。在这里,均值和标准差都是0.5。
    使用给定的均值和标准差,这会将值范围从[0,1]转换为[-1,1]。

整个transform的目的是:

  • 将图像数据从PIL格式转换为PyTorch张量格式。
  • 将像素值从[0,255]范围转换为[0,1]范围。
  • 使用给定的均值和标准差进一步标准化像素值,使其范围为[-1,1]。

初始化模型、损失函数和优化器

  • model = MLP():

    • 这里我们实例化了我们之前定义的MLP类,从而创建了一个多层感知器(MLP)模型。
  • criterion = nn.CrossEntropyLoss():

    • 在分类任务中,交叉熵损失函数 (CrossEntropyLoss) 是最常用的损失函数之一。它衡量真实标签和预测之间的差异。
    • 注意:CrossEntropyLoss在内部执行softmax操作,因此模型输出应该是未经softmax处理的原始分数(logits)。
  • optimizer = optim.Adam(model.parameters(), lr=0.001):

    • 优化器负责更新模型的权重,基于计算的梯度来减少损失。
    • Adam是一种流行的优化器,它结合了两种扩展的随机梯度下降:Adaptive Gradients 和 Momentum。
    • model.parameters()是传递给优化器的,它告诉优化器应该优化/更新哪些权重。
    • lr=0.001定义了学习率,这是一个超参数,表示每次权重更新的步长大小。

常见的相关资料解答

  1. 模型 (在torch.nn中):

除了基本的MLP外,PyTorch提供了很多预定义的层和模型,常见的包括:

Convolutional Neural Networks (CNNs):
    nn.Conv2d: 2D卷积层,常用于图像处理。
    nn.Conv3d: 3D卷积层,常用于视频处理或医学图像。
    nn.MaxPool2d: 最大池化层。

Recurrent Neural Networks (RNNs):
    nn.RNN: 基本的RNN层。
    nn.LSTM: 长短时记忆网络。
    nn.GRU: 门控循环单元。

Transformer Architecture:
    nn.Transformer: 用于自然语言处理任务的Transformer模型。

Batch Normalization, Dropout等:
    nn.BatchNorm2d: 批量归一化。
    nn.Dropout: 防止过拟合的正则化方法。
  1. 损失函数 (在torch.nn中):

常见的损失函数有:

Classification:
    nn.CrossEntropyLoss: 用于分类任务的交叉熵损失。
    nn.BCEWithLogitsLoss: 用于二分类任务的二元交叉熵损失,包括内部的sigmoid操作。
    nn.MultiLabelSoftMarginLoss: 用于多标签分类任务。

Regression:
    nn.MSELoss: 均方误差,用于回归任务。
    nn.L1Loss: L1误差。

Generative models:
    nn.KLDivLoss: Kullback-Leibler散度,常用于生成模型。
  1. 优化器 (在torch.optim中):

常见的优化器有:

optim.SGD: 随机梯度下降。
optim.Adam: 一个非常受欢迎的优化器,结合了AdaGrad和RMSProp的特点。
optim.RMSprop: 常用于深度学习任务。
optim.Adagrad: 自适应学习率优化器。
optim.Adadelta: 类似于Adagrad,但试图解决其快速降低学习率的问题。
optim.AdamW: Adam的变种,加入了权重衰减。

在这里插入图片描述

每文一语

学习是不断的发展的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1106979.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Qt QSlider滑动条小项目

QSlider 是滑动条控件,滑动条可以在一个范围内拖动,并将其位置转换为整数 1. 属性和方法 QSlider 继承自 QAbstractSlider,它的绝大多数属性都是从 QAbstractSlider 继承而来的。 2.QSlider信号 - `valueChanged(int value)`: 当滑块的值改变时发出信号,传递当前滑块的值…

mysql检验分区性能的操作

mysql检验分区性能的操作 创建两个结构相同但是一个有分区另外一个没有分区的表 如上图我们给part_tab5创建的分区为1024个,因为mysql中允许最多有1024个分区;之前我测试的是创建8个分区,然后插入500万条数据,然后按照id查询&…

关于页面优化

一、 js优化 js文件内部 1、减少重复代码的使用,精简代码 2、减少请求次数,如果不是需要实时的数据,可以将请求结果缓存在js变量中,后续直接使用变量的值 3、减少不必要的dom操作,例如:用innerHTMl代替do…

小魔推短视频裂变工具,如何帮助实体行业降本增效?

在如今的互联网时代,大多数的实体老板都在寻找不同的宣传方法来吸引客户,现在短视频平台已经成为重中之重的获客渠道之一,而如何在这个日活用户超7亿的平台获取客户,让更多人知道自己的门店、自己的品牌,泽成为了不少老…

uniapp vue3 使用pinia存储数据

import { defineStore } from pinia;export const userInfo defineStore(userInfo, {state: () > {return {userToken: uni.getStorageSync(token) || ,};},actions: {// 添加tokenupdateToken(token: string) {uni.setStorageSync(token, token);this.userToken token}} …

Apache Doris (四十三): Doris数据更新与删除 - Update数据更新

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录 1. Update数据更新原理

全面解决找不到vcruntime140_1.dll无法执行此代码问题的5方法

vcruntime140_1.dll是一个动态链接库文件,它是Microsoft Visual C 2015 Redistributable的一部分。当计算机中缺少这个文件时,可能会导致一些应用程序无法正常运行,从而影响我们的工作和生活。 一、问题场景 1. 在使用Windows操作系统的过程…

QTday02(常用类、UI界面下的开发、信号与槽)

今日任务 1. 使用手动连接,将登录框中的取消按钮使用qt4版本的连接到自定义的槽函数中,在自定义的槽函数中调用关闭函数 将登录按钮使用qt5版本的连接到自定义的槽函数中,在槽函数中判断ui界面上输入的账号是否为"admin"&#x…

出差学小白知识No5:ubuntu连接开发板|上传源码包|板端运行的环境部署

1、ubuntu连接开发板&#xff1a; 在ubuntu终端通过ssh协议来连接开发板&#xff0c;例如&#xff1a; ssh root<IP_address> 即可 这篇文章中也有关于如何连接开发板的介绍&#xff0c;可以参考SOC侧跨域实现DDS通信总结 2、源码包上传 通过scp指令&#xff0c;在ub…

1018hw

#include "mainwindow.h" #include "ui_mainwindow.h"MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);//窗口名this->setWindowTitle("QQ");this->setWindowIcon(QIcon(&q…

位段——(详细图解,保姆宗师级教程,包会,从基础概念到精通实战应用)

位段——大项目中结构体节省空间之手段 学习目标&#xff1a; 位段是什么 位段的内存分配 位段的平台局限性和应用 学习内容&#xff1a; 1.位段是什么 C中的位段&#xff08;Bit fields&#xff09;是一种用于有效利用内存的特性&#xff0c;可以在结构体中定义成员变量的…

朴素贝叶斯(基于概率论)

释义 贝叶斯定理是“由果溯因”的推断&#xff0c;所以计算的是"后验概率" 其中&#xff1a; P(A|B) 表示在事件 B 已经发生的条件下&#xff0c;事件 A 发生的概率。 P(B|A) 表示在事件 A 已经发生的条件下&#xff0c;事件 B 发生的概率。 P(A) 和 P(B) 分别表示事…

贴片电阻材质:了解电子元件的核心构成 | 百能云芯

在现代电子设备中&#xff0c;贴片电阻是一类至关重要的 passives 元件&#xff0c;广泛用于各种电路和应用中。贴片电阻的性能取决于多个因素&#xff0c;其中材质是其中之一。云芯将带您深入探讨贴片电阻的不同材质&#xff0c;探讨不同材质对电子元件性能的影响&#xff0c;…

深入理解算法:从基础到实践

深入理解算法&#xff1a;从基础到实践 1. 算法的定义2. 算法的特性3. 算法的分类按解决问题的性质分类&#xff1a;按算法的设计思路分类&#xff1a; 4. 算法分析5. 算法示例a. 搜索算法示例&#xff1a;二分搜索b. 排序算法示例&#xff1a;快速排序c. 动态规划示例&#xf…

【考研数学】概率论与数理统计 —— 第六章 | 数理统计基本概念(1,基本概念)

文章目录 引言一、基本概念1.1 总体1.2 样本1.3 统计量1.4 顺序统计量 写在最后 引言 以前学概率论的时候&#xff0c;不知道后面的数理统计是什么&#xff0c;所以简称都把后面的省略掉了。现在接触的学科知识多了&#xff0c;慢慢就对数理统计有了直观印象。 尤其是第一次参…

刷题日记1

最近在用JavaScript刷动态规划的题组&#xff0c;刷了一半感觉只刷题不写笔记的话印象没那么深刻&#xff0c;所以从今天开始来记录一下刷题情况。 力扣T300 300. 最长递增子序列 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列 是由数组派生而…

超实用!了解github的热门趋势和star排行是必须得!

在当今的技术领域中&#xff0c;GitHub 已经成为了开发者们分享和探索代码的重要平台。作为全球最大的开源社区&#xff0c;GitHub上托管了数以亿计的项目&#xff0c;其中包括了各种各样的技术栈和应用。对于开发者来说&#xff0c;了解GitHub上的热门趋势和star排行是非常重要…

Java10年技术架构演进

一、前言 又快到了1024&#xff0c;现代人都喜欢以日期的特殊含义来纪念属于自己的节日。虽然有点牵强&#xff0c;但是做件事情&#xff0c;中国人总喜欢找个节日来纪念&#xff0c;程序员也是一样。甚至连1111被定义成光棍节&#xff0c;这也算再无聊不过了。不过作为程序员…

基于百度API的车牌识别计费系统

1&#xff0c;车牌识别API 介绍&#xff1a; 百度车牌识别API是一款基于人工智能算法的车牌识别服务&#xff0c;可以识别包括普通车牌、新能源车牌在内的多种车牌类型&#xff0c;并支持高精度的识别结果输出。其主要功能特点包括&#xff1a; 普通车牌和新能源车牌的识别&a…

首发AI原生应用开发平台——千帆AI原生应用开发工作台,加速企业AI应用落地

为了满足企业对于敏捷和高效地进行AI原生应用开发与运维的需求&#xff0c;并降低相关开发的门槛&#xff0c;百度智能云最新发布了“千帆AI原生应用开发工作台”。该工作台将开发大型模型应用程序的常见模式、工具和流程进行了整合&#xff0c;使得开发者可以聚焦于自身业务&a…