昇思25天学习打卡营第5天 | 神经网络构建

news2025/1/11 21:09:12

1. 神经网络构建

神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

下面我们将构建一个用于Mnist数据集分类的神经网络模型。

# 导入MindSpore框架,这是一个面向AI应用的全场景深度学习框架
import mindspore

# 从MindSpore框架中导入nn(神经网络相关模块)和ops(操作相关模块)
# nn模块包含构建神经网络所需的各种层和激活函数等
# ops模块提供了执行计算任务的各种操作,例如数学运算和数组操作等
from mindspore import nn, ops

1.1 定义模型类

当我们定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

construct意为神经网络(计算图)构建,相关内容详见使用静态图加速。

# 定义一个Network类,继承自nn.Cell,Cell是MindSpore中构建网络的基类
class Network(nn.Cell):
    # 定义初始化方法,用于完成神经网络的构建
    def __init__(self):
    	# 调用父类的构造函数
        super().__init__()  
        # 创建一个Flatten层,用于将输入展平为一维向量
        self.flatten = nn.Flatten()  
        # 创建一个SequentialCell,它将多个层按顺序组合起来
        self.dense_relu_sequential = nn.SequentialCell(
            # 第一个全连接层,输入节点(神经元)数为28*28(假设是MNIST图像的大小),输出节点数为512
            # 权重初始化为正态分布,偏置初始化为0
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            # ReLU激活函数
            nn.ReLU(),  
            # 第二个全连接层,512个输入节点,512个输出节点
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            # ReLU激活函数
            nn.ReLU(),  
            # 第三个全连接层,512个输入节点,10个输出节点(假设是10个类别)
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    # construct方法是实现Cell向前传播的逻辑
    def construct(self, x):
        # 使用Flatten层将输入x展平
        x = self.flatten(x)  
        # 将展平后的x通过全连接层和激活函数
        logits = self.dense_relu_sequential(x)  
        # 返回最后的logits,即模型的输出
        return logits  

# 创建Network类的实例,这将初始化神经网络模型的结构
model = Network()

# 打印模型实例,这将输出模型的结构信息,包括各个层的类型和参数
print(model)

输出:

Network<
  (flatten): Flatten<>
  (dense_relu_sequential): SequentialCell<
    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
    (1): ReLU<>
    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
    (3): ReLU<>
    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
    >
  >

为了方便理解,我把这个神经网络画出来了
network

为了方便理解,把tensor也画进去了,实际上,这个网络结构定义好的时候是不包含具体的向量值的。
只是一个算法结构定义,flatten层和神经元内都是没有值的。
只有在训练或预测时,传入了Tensor值时,才会如上图所示
并且Tensor基本上会做过数据变换,传入易于神经网络计算的值,不然全是正值的话,ReLU层就没有存在的必要了。
另外,模型在被创建时,一般会完成参数的初始化,即图上所有的a(mn)和b(mn)都会有初始值。
所谓对模型的训练,就是不断调整这些初始参数,使得他们能在尽可能多的情况下满足给定输入可以计算出预期输出的过程。
而模型推理指的是,使用已练训练好的这些模型参数和整个模型算法结构,给定输入后,推测出结果的过程。

1.2 使用模型进行预测

# 使用ops模块的ones函数创建一个形状为(1, 28, 28)的张量,并指定数据类型为float32
# 这个张量可以代表一个单通道的28x28像素的图像,所有元素都初始化为1
X = ops.ones((1, 28, 28), mindspore.float32)

# 将创建的张量X作为输入传递给模型model,计算前向传播的输出
logits = model(X)

# 打印模型输出的logits,这通常代表了模型对输入数据的预测结果
## 在Python交互式环境(如Python shell或Jupyter notebook)中,当你输入一个变量名并按下回车时,系统会自动调用该变量的__repr__或__str__方法来打印出一个可读的字符串表示,以便于用户查看变量的内容。
## 如果你在脚本文件中执行这段代码,而没有使用print函数,变量logits的内容将不会被打印出来。
logits

输出:

Tensor(shape=[1, 10], dtype=Float32, value=
[[ 4.55602678e-03, -1.45907002e-02,  1.84002449e-03 ...  8.69915355e-03,  6.90837333e-04, -4.99743177e-03]])
# 使用nn模块中的Softmax函数计算logits的softmax概率分布
# Softmax函数通常用于多分类问题,将 logits 转换为概率分布
# 通俗理解就是传入一张图片,Softmax会将这张图片是0-9每个数字(代表10个类别)的概率都计算出来。
# 参数axis=1表示在第二个维度(索引为1)上应用softmax,即对每个样本的所有类别的logits进行归一化
pred_probab = nn.Softmax(axis=1)(logits)

# 使用argmax函数找到每个样本在概率分布中最大值的索引,概率最大的类别,即预测的类别
# argmax函数的参数1表示在第二个维度(索引为1)上找到最大值的索引
y_pred = pred_probab.argmax(1)

# 打印预测的类别
# f-string是一种格式化字符串的语法,可以在字符串中嵌入表达式
# {y_pred}会被替换为y_pred变量的值
print(f"Predicted class: {y_pred}")

输出:

Predicted class: [4]

1.3 模型层详解

本节中我们分解上节构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

# 使用ops模块的ones函数创建一个形状为(3, 28, 28)的张量,并指定数据类型为float32
# 这个张量可以代表一个3通道的28x28像素的图像,所有元素都初始化为1
input_image = ops.ones((3, 28, 28), mindspore.float32)

# 打印输入图像张量的形状
print(input_image.shape)

输出:

(3, 28, 28)

1.3.1 nn.Flatten(展平层)

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

# 创建一个Flatten层实例,用于将输入张量展平为一维向量
flatten = nn.Flatten()

# 使用Flatten层处理输入图像张量input_image,将其展平为一维向量
# 由于输入图像是3个通道的28x28像素,展平后每个通道将变成一个784长度的一维向量
flat_image = flatten(input_image)

# 打印展平后图像张量的形状
# 输出将是(3, 784),表示有三个一维向量,每个向量长度为784
print(flat_image.shape)

输出:

(3, 784)

1.3.2 nn.Dense(全连接层)

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

# 创建一个全连接层layer1,其中in_channels参数设置为28*28,out_channels参数设置为20
# 这意味着该层将有28*28个输入节点和20个输出节点
layer1 = nn.Dense(in_channels=28*28, out_channels=20)

# 将展平后的图像flat_image作为输入传递给全连接层layer1
# 计算全连接层的输出,这将是20个特征的一维向量,每个向量对应于输入图像的一个样本
hidden1 = layer1(flat_image)

# 打印全连接层输出的形状和内容
# 输出形状将是(3, 20),表示有三个样本,每个样本有20个特征
# 输出内容将是这些特征的值
print(hidden1.shape, hidden1)

输出:

(3, 20) [[ 0.5608331  -0.06500022  0.5195999   0.45284656  0.22346526 -0.22476278
  -0.340589   -0.43673825 -0.57077926 -0.44966274  0.3863637   0.52841353
  -0.44325724  1.1107857  -1.2462549  -0.17119673  0.46310893 -0.8667695
  -0.204903    0.0104395 ]
 [ 0.5608331  -0.06500022  0.5195999   0.45284656  0.22346526 -0.22476278
  -0.340589   -0.43673825 -0.57077926 -0.44966274  0.3863637   0.52841353
  -0.44325724  1.1107857  -1.2462549  -0.17119673  0.46310893 -0.8667695
  -0.204903    0.0104395 ]
 [ 0.5608331  -0.06500022  0.5195999   0.45284656  0.22346526 -0.22476278
  -0.340589   -0.43673825 -0.57077926 -0.44966274  0.3863637   0.52841353
  -0.44325724  1.1107857  -1.2462549  -0.17119673  0.46310893 -0.8667695
  -0.204903    0.0104395 ]]

1.3.3 nn.ReLU(激活函数)

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

# 使用ReLU激活函数对全连接层的输出hidden1进行非线性处理
# ReLU激活函数将替换hidden1中的所有负值元素为0
hidden1 = nn.ReLU()(hidden1)

# 打印经过ReLU激活函数处理后的输出
# 输出将是经过ReLU处理的特征值,其中所有负值已经被替换为0
print(f"After ReLU: {hidden1}")

输出:

After ReLU: [[0.5608331  0.         0.5195999  0.45284656 0.22346526 0.
  0.         0.         0.         0.         0.3863637  0.52841353
  0.         1.1107857  0.         0.         0.46310893 0.
  0.         0.0104395 ]
 [0.5608331  0.         0.5195999  0.45284656 0.22346526 0.
  0.         0.         0.         0.         0.3863637  0.52841353
  0.         1.1107857  0.         0.         0.46310893 0.
  0.         0.0104395 ]
 [0.5608331  0.         0.5195999  0.45284656 0.22346526 0.
  0.         0.         0.         0.         0.3863637  0.52841353
  0.         1.1107857  0.         0.         0.46310893 0.
  0.         0.0104395 ]]

可以看到,对比激活函数应用之前,所有的负数都被替换成了0

1.3.4 nn.SequentialCell(有序神经网络)

nn.SequentialCell是一个有序的Cell容器。输入Tensor将按照定义的顺序通过所有Cell。我们可以使用SequentialCell来快速组合构造一个神经网络模型。

# 创建一个SequentialCell实例seq_modules,它将按照顺序组合多个层和激活函数
# 第一个元素是之前定义的Flatten层,用于展平输入图像
# 第二个元素是之前定义的全连接层layer1,它将展平后的图像映射到20个特征
# 第三个元素是ReLU激活函数,用于对layer1的输出进行非线性处理
# 第四个元素是另一个全连接层,它将20个特征映射到10个输出节点,通常用于分类任务的最后一层
seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

# 将输入图像input_image通过seq_modules进行处理
# 这将依次通过Flatten层、layer1、ReLU激活函数和最后一个全连接层
logits = seq_modules(input_image)

# 打印经过seq_modules处理后的输出logits的形状
# 输出形状将是(3, 10),表示有三个样本,每个样本有10个输出节点
print(logits.shape)

输出:

(3, 10)

1.3.4 nn.Softmax(概率分布函数)

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

# 创建一个Softmax实例,用于计算softmax概率分布
# 参数axis=1表示在第二个维度(索引为1)上应用softmax
# 这通常用于多分类问题,将logits转换为概率分布
softmax = nn.Softmax(axis=1)

# 使用softmax函数计算logits的softmax概率分布
# 这将输出每个类别对应的概率,概率总和为1
pred_probab = softmax(logits)

1.4 模型参数

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

# 打印模型的结构,这将显示模型的层次结构和参数信息
print(f"Model structure: {model}\n\n")

# 遍历模型中的所有参数和名称
for name, param in model.parameters_and_names():
    # 打印每个层的名称
    print(f"Layer: {name}\n")
    
    # 打印每个参数的形状
    print(f"Size: {param.shape}\n")
    
    # 打印每个参数的前两个值,用于查看参数的初始化情况
    print(f"Values : {param[:2]} \n")

输出:

Model structure: Network<
  (flatten): Flatten<>
  (dense_relu_sequential): SequentialCell<
    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
    (1): ReLU<>
    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
    (3): ReLU<>
    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
    >
  >


Layer: dense_relu_sequential.0.weight
Size: (512, 784)
Values : [[-0.01270912 -0.00553937 -0.00622345 ...  0.00974897  0.00378853
  -0.00879488]
 [ 0.00454485  0.00105424  0.02829224 ... -0.00480925  0.00859034
  -0.0075234 ]] 

Layer: dense_relu_sequential.0.bias
Size: (512,)
Values : [0. 0.] 

Layer: dense_relu_sequential.2.weight
Size: (512, 512)
Values : [[-0.00303549  0.02051559  0.03005496 ...  0.00813595 -0.02086384
  -0.00501902]
 [ 0.00523915  0.00595684 -0.02108657 ... -0.00816013  0.00160791
  -0.00521205]] 

Layer: dense_relu_sequential.2.bias
Size: (512,)
Values : [0. 0.] 

Layer: dense_relu_sequential.4.weight
Size: (10, 512)
Values : [[ 0.00922695 -0.00915574 -0.01120596 ... -0.00575607 -0.00918559
  -0.00985601]
 [-0.01312499  0.01030371  0.01826839 ...  0.00239934 -0.01605123
  -0.0015749 ]] 

Layer: dense_relu_sequential.4.bias
Size: (10,)
Values : [0. 0.] 

更多内置神经网络层详见mindspore.nn API。
网络构建

2. 小结

今天主要学习了昇思神经网络的构建,包括模型定义、使用模型预测,深入学习了模型展平层、全连接层、激活函数、有序神经网络、概率分布函数和模型参数的相关知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1897746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MobaXterm不显示隐藏文件

MobaXterm在左边显示隐藏文件&#xff0c;以.开头的文件&#xff0c;想让它不显示&#xff0c;点击红框按钮就可以了

计算机视觉——opencv快速入门(二) 图像的基本操作

前言 上一篇文章中我们介绍了如何配置opencv&#xff0c;而在这篇文章我们主要介绍的是如何使用opencv来是实现一些常见的图像操作。 图像的读取&#xff0c;显示与存储 读取图像文件 在opencv中我们利用imread函数来读取图像文件,函数语法如下&#xff1a; imagecv2.imre…

Python 可视化 web 神器:streamlit、Gradio、dash、nicegui;低代码 Python Web 框架:PyWebIO

官网&#xff1a;https://streamlit.io/ github&#xff1a;https://github.com/streamlit/streamlit API 参考&#xff1a;https://docs.streamlit.io/library/api-reference 最全 Streamlit 教程&#xff1a;https://juejin.cn/column/7265946243196436520 Streamlit-中文文档…

如何在 Ubuntu上搭建 LAMP

远程登录 Ubuntu系统环境 ssh (User)(IP) # 比如&#xff1a;ssh lennlouis192.168.207.128 为安全起见&#xff0c;建议你使用 root 登录 VPS 后创建一个具有 sudo 权限的帐号。 安装和配置 Apache 2 Apache Http Server 是一个开源的&#xff0c;非常流行&#xff0c;使用…

直播预告 | VMware大规模迁移实战,HyperMotion助力业务高效迁移

2006年核高基专项启动&#xff0c;2022年国家79号文件要求2027年央国企100%完成信创改造……国家一系列信创改造政策的推动&#xff0c;让服务器虚拟化软件巨头VMware在中国的市场份额迅速缩水。 加之VMware永久授权的取消和部分软件组件销售策略的变更&#xff0c;导致VMware…

QoS-优先级映射

拓扑图 配置 先完成此配置复杂流分类-CSDN博客 配置qos map-table 接口开启信任DSCP qos map-table dscp-lpinput 32 output 5input 46 output 6 # interface GigabitEthernet0/0/0trust dscp override # AR1上10.1.1.1 ping 3.3.3.3&#xff0c;该流量标记为EF EF映射为p…

vs 远程链接ssh 开发 简单实验

1.概要 动态编译语言&#xff0c;跨平台必须做分别的编译&#xff0c;比如linux和windows。如何再windows环境下开发编译出linux平台的程序呢&#xff0c;vs支持远程链接编辑&#xff0c;就是再vs中写代码&#xff0c;但是编译确是链接远程的环境编译的。 2.环境准备 2.1 vs…

DataWhale-吃瓜教程学习笔记 (七)

学习视频**&#xff1a;第6章-支持向量机_哔哩哔哩_bilibili 西瓜书对应章节&#xff1a; 第六章 支持向量机 - 算法原理 几何角度 对于线性可分数据集&#xff0c;找距离正负样本距离都最远的超平面&#xff0c;解是唯一的&#xff0c;泛化性能较好 - 超平面 - 几何间隔 例…

如何让自动化测试更加灵活简洁?

简化的架构对于自动化测试和主代码一样重要。冗余和不灵活性可能会导致一些问题&#xff1a;比如 UI 中的任何更改都需要更新多个文件&#xff0c;测试可能在功能上相互重复&#xff0c;并且支持新功能可能会变成一项耗时且有挑战性的工作来适应现有测试。 页面对象模式如何理…

前后端数据交互流程

一、前言 用户在浏览器访问一个网站时&#xff0c;会有前后端数据交互的过程&#xff0c;前后端数据交互也有几种的情况&#xff0c;一下就简单的来说明一下 二、原理 介绍前后端交互前先来了解一下浏览器的功能&#xff0c;浏览器通过渲染引擎和 JavaScript 引擎协同工作&am…

东北财税之星:董女士的家乡创业记

乐财业智慧财税赋能平台&#xff0c;是一个帮助财税机构专业提升、业务增长&#xff0c;让财税生意更好做的综合赋能平台。聚焦财税公司业绩增长&#xff0c;预计2027年帮助2000家财税合伙人利润增长300%&#xff0c;致力打造轻量化、批量化、智能化的”业财税“一体财税服务生…

【C++航海王:追寻罗杰的编程之路】关联式容器的底层结构——AVL树

目录 1 -> 底层结构 2 -> AVL树 2.1 -> AVL树的概念 2.2 -> AVL树节点的定义 2.3 -> AVL树的插入 2.4 -> AVL树的旋转 2.5 -> AVL树的验证 2.6 -> AVL树的性能 1 -> 底层结构 在上文中对对map/multimap/set/multiset进行了简单的介绍&…

从CPU的视角看C++的构造函数和this指针

从汇编角度&#xff0c;清晰的去看构造函数和this指针到底是个什么东西呢&#xff1f;也许可以解决你的一点小疑问 首先写一个很简单的代码demo&#xff1a; class A{ public:int a;A(){;}void seta(int _a){a_a;}A* getA(){return this;} };int fun1(int px){return px; }in…

烟台LP-SCADA系统如何实现实时监控和过程控制?

关键字:LP-SCADA系统, 传感器可视化, 设备可视化, 独立SPC系统, 智能仪表系统,SPC可视化,独立SPC系统 LP-SCADA&#xff08;监控控制与数据采集&#xff09;系统实现实时监控和过程控制的主要原理和组件如下&#xff1a; 数据采集&#xff1a;LP-SCADA系统通过部署在现场的传…

跨境人最怕的封店要怎么规避?

跨境人最怕的是什么&#xff1f;——封店 造成封店的原因很多&#xff0c;IP关联、无版权售卖、虚假发货等等&#xff0c;其中IP关联这个问题导致店铺被封在跨境商家中简直是屡见不鲜 IP关联&#xff0c;是指被海外平台检测到多家店铺开设在同一个站点上的情况。我们知道有些…

用html+css设计一个列表清单小卡片

目录 简介: 效果图: 源代码: 可能的问题: 简介: 这个HTML代码片段是一个简单的列表清单设计。它包含一个卡片元素(class为"card"),内部包含一个无序列表(ul),列表项(li)前面有一个特殊的符号(△)。整个卡片元素设计成300px宽,150px高,具有圆角边…

Linux-C语言实现一个进度条小项目

如何在linux中用C语言写一个项目来实现进度条&#xff1f;&#xff08;如下图所示&#xff09; 我们知道\r是回车&#xff0c;\n是换行&#xff08;且会刷新&#xff09; 我们可以用 \r 将光标移回行首&#xff0c;重新打印一样格式的内容&#xff0c;覆盖旧的内容&#xff0c;…

lua入门(2) - 数据类型

前言 本文参考自: Lua 数据类型 | 菜鸟教程 (runoob.com) 希望详细了解的小伙伴还请查看上方链接: 八个基本类型 type - 函数查看数据类型: 测试程序: print(type("Hello world")) --> string print(type(10.4*3)) --> number print(t…

opencv概念以及安装方法

#opencv相关概念介绍 Open Source Computer Vision Library 缩写 opencv 翻译&#xff1a;开源的计算机视觉库 &#xff0c;英特尔公司发起并开发&#xff0c;支持多种编程语言&#xff08;如C、Python、Java等&#xff09;&#xff0c;支持计算机视觉和机器学习等众多算法&a…

ONNXRuntime与CUDA所对应的版本

官方链接&#xff1a; NVIDIA - CUDA | onnxruntime