从PyTorch官方的一篇教程说开去(5 - 神经网络 torch.nn)

news2025/1/17 4:08:38

神经网络长啥样?有没有四只眼睛八条腿?

借图镇楼 - 

真的是非常经典,可以给下面的解释省掉很多力气。

分3个维度阐述 - 

1)输入数据集。假如你自己去微调一下大模型就知道,最开始的一步就是要准备(足够大的)数据集,比如百度就要求1kw条+的数据集,否则就不给你训练。PyTorch官方的数据集是用的著名的FashionMNIST,这是一个由许多28pixel * 28pixel的黑白像素组成的图片(img),类似于windows中的缩略图,每一个图片有标记(label),以及序号(idx)。

感兴趣的您也可以自行下载,下载来的是压缩格式.gz,里面图片是用PIL格式存储的单一二进制文件。如果想看看这些黑白缩略图长啥样呢,可以下载和渲染看看 - 

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="drive/MyDrive/",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="drive/MyDrive/",
    train=False,
    download=True,
    transform=ToTensor()
)

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    print(sample_idx)
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2)输入有了。对于以上数据集,那就是一大堆784的数组或者列表。你需要把这个转化为1x784的tensor张量或者矢量或者向量,被称为flatten层,也就是把所有信息一维化或者叫扁平化。tensor这个词我们回头专门讨论一下,现在理解为多维数组就好。

那接下来整几层网络?分别干啥用?每一层多少个神经元?最后的输出咋整?

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512), 
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

y=nn.Linear(28*28, 512)
print(y,"\n")
z=nn.ReLU(y)
print(z,"\n")

以上是官方的代码,初始化第二层用了512,这个用来存储输出的特征数量,那么784个数用512个特征是否足够表达呢?会不会欠拟合or过拟合呢?这个就完全是经验值了,或者说,要根据你特定的场景或者模型来跑跑看了。反正本例中,每个输入图像通过这个层后,会被转换成一个包含512个数值的向量。中间的隐藏层,是为了方便把输入过渡到输出,你把数字换成1024也完全没有毛病(有可能你的显卡会对你提出抗议,给你穿个小鞋啥的)

本例的输出是为了统计图像的类别,类别一共有10类,从0-T-shirt到9-踝靴,所以输出用10位来表达足够,这里采用了一种被成为one-hot编码的常用的数据表示方法,One-hot编码将分类数据表示为二进制(0和1)的格式,其中每个类别由一个唯一的二进制向量表示,该向量的长度等于类别的总数。例如,假设我们有一个包含三个可能类别的数据集:A、B和C。我们可以用以下one-hot编码表示它们:

  • 类别 A: [1, 0, 0]
  • 类别 B: [0, 1, 0]
  • 类别 C: [0, 0, 1]

3)综上,我们根据输入输出的形式,设计了一个4层(线性化+输入层,隐藏中间层1,隐藏中间层2,输出层)。现在问题来了,这个稚嫩的神经网络怎么干活呢?它怎么不断的学习和奖惩,从而准确的看到个图就报出类别名呢?

因为用了torch库,所以一些细节在代码中被隐藏了。

比如,为了把784个像素和512个特征值对应起来,其实隐含生成了一个权重矩阵和偏置向量。权重是一个矩阵,其行数等于输出特征数(512),列数等于输入特征数(784)。偏置是一个长度为输出特征数(512)的向量。

这个就和上一章节中我们讨论过的玩游戏的AI的脑子对应上了。我们需要初始化权重矩阵(全0或者全1),和偏置向量(从0或者1开始)。然后使用贪婪算法和梯度下降算法,不断的迭代优化。

呃,还漏了一点,激活函数ReLU(),这个是为了让我们的初始化和迭代过程更加平滑一点,避免0或者1产生一些诸如“死亡神经元”之类的算法缺陷(即当输入为负时,梯度为零,导致部分神经元停止更新。)。激活函数有很多种,默认的是f(x) = max(0, x),也就是在x大于0时直接输出x,小于0时输出0。它简单且计算效率高,是当前最流行的激活函数之一,尤其适用于隐藏层。其实可以选择的激活函数还有很多,比如Tanh,Exponential,Sigmoid等等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个项目的坎坷一生

大家好,我是苍何。 目前呢,主要是负责部门的项目管理和团队管理相关工作,今天想和大家分享一下企业级标准的项目管理流程以及苍何的实践。 通过本文,能帮助你更快的在企业中上手项目并定位好自己的角色,别人一脸懵逼…

高效恢复误删文件:2024年数据恢复工具

现在都是互联网的时代,数据已成为我们生活与工作中不可或缺的一部分。很多时候我们都依赖电子设备来存储数据,这也造成来数据丢失风险的增加。这时候如果掌握了一些数据恢复软件,比如转转大师数据恢复软件这种就能对你的电子存储数据有保障。…

如何在 Windows 系统环境下安装 Tesseract OCR? ( •̀ ω •́ )✧

第一步:下载Tesseract OCR引擎安装包 🍑 访问Tesseract的GitHub发布页面(https://github.com/tesseract-ocr/tesseract)或第三方下载站点(https://digi.bib.uni-mannheim.de/tesseract/),下载适…

Docker与LXC差异以及相关命令

容器:Docker与LXC差异以及相关命令 ​ LXC与Docker对比,LXC只实现了进程沙盒化,不支持在不同的机器上进行移植;Docker将应用的所有配置和环境进行了抽象,打包到一个容器中,此容器可以在任何安装了docker的…

mybatis-plus实现分页功能

第一步:添加mybatis-plus为分页所使用的拦截器插件 (不用这个的话sql里面的limit关键字无法实现,也就没办法实现查询操作) 代码: Configuration public class mybatis_plus_config {Beanpublic MybatisPlusIntercept…

4大类75项BUG场景大盘点!测试人必看!

本文主要针对填写BUG时,bug分类共分为多少项,每一项内容都有哪些场景,并结合具体错误案例进行简单分析。 一UI表示层 在软件测试和开发中,当提到“用户UI”类型的bug时,通常是指与用户界面(User Interface…

Weights2wights Interpreting the Weight Space of Customized Diffusion Models

Weights2wights: Interpreting the Weight Space of Customized Diffusion Models 导语 可控生成是图像生成领域的一个重要方向。从最基础的文本条件生成,到 ControlNet、IP-Adapter 等图像条件生成,再到各种概念定制化生成,扩散模型的可控…

InternLM Git 基础知识

提交一份自我介绍。 创建并提交一个项目。

采用GDAL批量波段运算计算植被指数0基础教程

采用GDAL批量波段运算计算植被指数0基础教程 1. 引言 在传统的遥感数据处理方法中,通常使用ArcGis或ENVI软件进行波段运算。然而,这些软件在处理大量数据时往往效率低下。有没有一种方法可以批量进行波段运算,一下子计算几十个植被指数&…

将项目部署到docker容器上

通过docker部署前后端项目 前置条件 需要在docker中拉去jdk镜像、nginx镜像 docker pull openjdk:17 #拉取openjdk17镜像 docker pull nginx #拉取nginx镜像部署后端 1.打包后端项目 点击maven插件下面的Lifecycle的package 对后端项目进行打包 等待打包完成即可 2.将打…

【全志H616开发】Linux的热拔插UDEV机制

文章目录 udev简介工作原理Udev 配置文件和规则示例总结: udev简介 Udev 是 Linux 系统中设备管理的一部分,它负责管理动态设备节点并处理设备的热插拔。Udev 提供了一种在用户空间管理设备节点的机制,可以在设备插入或移除时自动执行相应的…

2024.7.30问题合集

2024.7.30问题合集 1.adb调试出现5037端口被占用的情况2.更改ip地址时出现以下问题3.RV1126 ip配置问题 1.adb调试出现5037端口被占用的情况 问题:5037端口被占用的情况 解决方案:将adb文件下的adb.exe和AdbWinApi.dll两个文件复制到C:\Windows\SysWOW6…

红外热成像仪的功能应用_鼎跃安全

红外热成像仪利用红外探测器接收被测目标物体发射的红外辐射能量;通过接收到红外辐射转化为电信号,将这些信号放大转化后,通过不同的颜色代表不同温度,从而直观的在电子屏显示出来,可以清晰的观察到物体的热分布。 热成…

flex/bison结合使用解析配置文件

flex是gnu linux下的语法分析器程序(lex则是Unix下的语法分析器),它将输入文件(yyin)的内容去匹配对应的匹配规则表达式,并返回一个token。注意,flex的copyright并不是gnu的。 bison是gnu linux下的yacc(Yet Another Compiler Compiler)&…

【计算机毕设论文】基于SpringBoot的成绩管理系统

💗博主介绍:✌全平台粉丝5W,高级大厂开发程序员😃,博客之星、掘金/知乎/华为云/阿里云等平台优质作者。 【源码获取】关注并且私信我 感兴趣的可以先收藏起来,同学门有不懂的毕设选题,项目以及论文编写等相…

3.5.4、查找和排序算法-排序算法下

快速排序 快速排序的基本思想是:通过一趟排序将待排的序列划分为独立的两个部分,其中一部分序列的元素均不大于另一部分记录的关键字,然后再分别对这两部分序列继续进行快速排序,以达到整个序列有序。 大致步骤如下:…

2024西安铁一中集训DAY27 ---- 模拟赛((bfs,dp) + 整体二分 + 线段树合并 + (扫描线 + 线段树))

文章目录 前言时间安排及成绩题解A. 倒水(bfs dp)B. 让他们连通(整体二分 按秩合并并查集 / kruskal重构树)C. 通信网络(线段树合并 二分)D. 3SUM(扫描线 线段树) 前言 T1没做出…

6万字,让你轻松上手的大模型 LangChain 框架

本文为我学习 LangChain 时对官方文档以及一系列资料进行一些总结~覆盖对Langchain的核心六大模块的理解与核心使用方法,全文篇幅较长,共计50000字,可先码住辅助用于学习Langchain。** 一、Langchain是什么? 如今各类…

FPGA实现LCD12864控制

目录 注意! a) 本工程采用野火征途PRO开发板,外接LCD12864部件进行测试。 b) 有偿提供代码!!!可以定制功能!!!有需要私信!!! c) 本文测试采用…

操作系统02

文章目录 Linux 内核 vs Windows 内核内核Linux 的设计MultiTaskSMPELFMonolithic Kernel **Windows 设计** 内存管理虚拟内存内存分段内存分页多级页表TLB 段页式内存管理Linux 内存布局内存分配的过程是怎样的?哪些内存可以被回收?回收内存带来的性能影…