昇思MindSpore学习总结五——网络构建

news2025/1/22 23:53:55

1、网络构建

        神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这        样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

mindspore.nn.Cell类——mindspore.nn.Cell(auto_prefix=Trueflags=None)

MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。

mindspore.nn 中神经网络层也是Cell的子类,如 mindspore.nn.Conv2d 、 mindspore.nn.ReLU 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。

【参数】

  • auto_prefix (bool,可选) - 是否自动为Cell及其子Cell生成NameSpace。该参数同时会影响 Cell 中权重参数的名称。如果设置为 True ,则自动给权重参数的名称添加前缀,否则不添加前缀。通常情况下,骨干网络应设置为 True ,否则会产生重名问题。用于训练骨干网络的优化器、 mindspore.nn.TrainOneStepCell 等,应设置为 False ,否则骨干网络的权重参数名会被误改。默认值: True 。

  • flags (dict,可选) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值: None 。

下面构建一个用于Mnist数据集分类的神经网络模型 。

2、定义模型

        当我们定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

import mindspore
from mindspore import nn, ops

construct意为神经网络(计算图)构建。 

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

 构建完成后,实例化Network对象,并查看其结构。

model = Network()
print(model)

         构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits

 通过一个nn.Softmax层实例来获得预测概率。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

 3、模型层

        分解上节构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

 3.1 nn.Flatten层

mindspore.nn.Flatten(start_dim=1end_dim=- 1)

沿着从 start_dim 到 end_dim 的维度,对输入Tensor进行展平。

【参数】

  • start_dim (int, 可选) - 要展平的第一个维度。默认值: 1 。

  • end_dim (int, 可选) - 要展平的最后一个维度。默认值: -1 。

输入——>x (Tensor) - 要展平的输入Tensor,输出——>Tensor。如果没有维度被展平,返回原始的输入 input,否则返回展平后的Tensor。如果 input 是零维Tensor,将会返回一个一维Tensor。

将28x28的2D张量转换为784大小的连续数组。

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

 3.2 nn.Dense层

mindspore.nn.Dense(in_channelsout_channelsweight_init=Nonebias_init=Nonehas_bias=Trueactivation=Nonedtype=mstype.float32)

全连接层。使用权重和偏差对输入进行线性变换。

适用于输入的密集连接层。公式如下:

outputs=activation(X∗kernel+bias)

其中 𝑋 是输入Tensor, activation 是激活函数, kernel 是一个权重矩阵,其数据类型与 𝑋 相同, bias 是一个偏置向量,其数据类型与 𝑋 相同(仅当has_bias为True时)。

【参数】

  • in_channels (int) - Dense层输入Tensor的空间维度。

  • out_channels (int) - Dense层输出Tensor的空间维度。

  • weight_init (Union[Tensor, str, Initializer, numbers.Number]) - 权重参数的初始化方法。数据类型与 x 相同。str的值引用自函数 initializer。默认值:None ,权重使用HeUniform初始化。

  • bias_init (Union[Tensor, str, Initializer, numbers.Number]) - 偏置参数的初始化方法。数据类型与 x 相同。str的值引用自函数 initializer。默认值:None ,偏差使用Uniform初始化。

  • has_bias (bool) - 是否使用偏置向量 bias 。默认值: True 。

  • activation (Union[str, Cell, Primitive, None]) - 应用于全连接层输出的激活函数。可指定激活函数名,如’relu’,或具体激活函数,如 mindspore.nn.ReLU 。默认值: None 。

  • dtype (mindspore.dtype) - Parameter的数据类型。默认值: mstype.float32 。

输入——>x (Tensor) - shape为 (∗,𝑖𝑛_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠) 的Tensor。参数中的 in_channels 应等于输入中的 𝑖𝑛_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠 

输出——>shape为 (∗,𝑜𝑢𝑡_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠) 的Tensor。

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

 3.3 nn.ReLU

         mindspore.nn.ReLU

逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。

ReLU(𝑖𝑛𝑝𝑢𝑡)=(𝑖𝑛𝑝𝑢𝑡)+=max(0,𝑖𝑛𝑝𝑢𝑡),

逐元素求 max(0,𝑖𝑛𝑝𝑢𝑡) 。

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

3.4 nn.SequentialCell 

        mindspore.nn.SequentialCell(*args)

构造Cell顺序容器。关于Cell的介绍,可参考 Cell。

SequentialCell将按照传入List的顺序依次将Cell添加。此外,也支持OrderedDict作为构造器传入。

【参数】

  • args (list, OrderedDict) - 仅包含Cell子类的列表或有序字典。

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)

 3.5 nn.Softmax

        mindspore.nn.Softmax(axis=- 1)

逐元素计算Softmax激活函数,它是二分类函数 mindspore.nn.Sigmoid 在多分类上的推广,目的是将多分类的结果以概率的形式展现出来。

对输入Tensor在轴 axis 上的元素计算其指数函数值,然后归一化到[0, 1]范围,总和为1。

Softmax定义为:

其中, 𝑖𝑛𝑝𝑢𝑡𝑖 是输入Tensor在轴 axis 上的第 𝑖 个元素。

【参数】

  • axis (int,可选) - 指定Softmax运算的轴axis,假设输入 input 的维度为input.ndim,则 axis 的范围为 [-input.ndim, input.ndim) ,-1表示最后一个维度。默认值: -1 。

        最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)

4、模型参数

         网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1877299.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

创新探析:我国AIGC产业规模有望在2030年破万亿,创意设计行业或迎来全面革新

在科技日新月异的今天,人工智能生成内容(AIGC)与创意设计行业的结合正以前所未有的速度推动着产业变革。随着技术的不断突破和市场需求的日益增长,我国AIGC产业规模有望在2030年突破万亿元大关,这一宏伟目标不仅是对技…

VUE-CLI脚手架项目的初步创建与配置

首先创建一个VUE项目,注意选择版本为 2.6.10 打开APP.vue文件,并且删除APP.vue中多余的代码 创建index.vue文件 在此文件中写入如下图片中的代码来初步创建页面 创建router目录,并且创建index.js 文件如下 在终端输入npm run serve 运行 然后…

代码随想录-Day43

52. 携带研究材料(第七期模拟笔试) 小明是一位科学家,他需要参加一场重要的国际科学大会,以展示自己的最新研究成果。他需要带一些研究材料,但是他的行李箱空间有限。这些研究材料包括实验设备、文献资料和实验样本等…

[OtterCTF 2018]Recovery

里克必须找回他的文件!用于加密文件的随机密码是什么 恢复他的文件 ,感染的文件 ? vmware-tray.ex 前面导出的3720.dmp 查找一下 搜索主机 strings -e l 3720.dmp | grep “WIN-LO6FAF3DTFE” 主机名 后面跟着一串 代码 aDOBofVYUNVnmp7 是不…

c++类和对象(三)日期类

类和对象 一.拷贝构造函数定义二.拷贝构造函数特征三.const成员函数权限权限的缩小权限的缩放大 四.隐式类型转换 一.拷贝构造函数定义 拷贝构造函数:只有单个形参,该形参是对本类类型对象的引用(一般常用const修饰),在用已存 在的类类型对象…

韩顺平0基础学Java——第33天

p653-674 坦克大战 继续上回游戏 将每个敌人的信息,恢复成Node对象,放进Vector里面。 播放音乐 使用一个播放音乐的类。 第二阶段结束了 网络编程 相关概念 (权当是复习计网了) 网络 1.概念:两台或多台设备通过一定物理设备连…

Java基础知识-集合类

1、HashMap 和 Hashtable 的区别? HashMap 和 Hashtable是Map接口的实现类,它们大体有一下几个区别: 1. 继承的父类不同。HashMap是继承自AbstractMap类,而HashTable是继承自Dictionary类。 2. 线程安全性不同。Hashtable 中的方…

【Nginx】源码安装

nginx官网:nginx: download 选择文档版本安装即可 1.安装依赖包 //一键安装上面四个依赖 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-devel 2.下载并解压安装包 //创建一个文件夹 cd /usr/local mkdir nginx cd nginx //将下载的nginx压缩…

蓝卓出席“2024C?O大会”,探讨智能工厂建设新路径

6月29日,“2024C?O大会”在金华成功举办。此次大会由浙江省企业信息化促进会主办,与以往CIO峰会不同,“C?O”代表了企业数字化中的核心决策者群体,包括传统的CIO、CEO、CDO等。 本次大会围绕C?O、AIGC与制造业、数据价值、未来…

MySQL之可扩展性(九)

可扩展性 直接连接 2.修改应用的配置 还有一个分发负载的办法是重新配置应用。例如,你可以配置多个机器来分担生成大报表操作的负载。每台机器可以配置成连接到不同的MySQL备库,并为第N个用户或网站生成报表。 这样的系统很容易实现,但如果…

第7章_低成本 Modbus 传感器的实现

文章目录 第7章 低成本 Modbus 传感器的实现7.1 硬件资源介绍与接线7.2 开发环境搭建7.3 创建与体验第 1 个工程7.3.1 创建工程7.3.2 配置调试器7.3.3 配置 GPIO 操作 LED 7.4 UART 编程7.4.1 使用 STM32CubeMX 进行配置1.UART12.配置 RS485方向引脚 7.4.2 封装 UART7.4.3 上机…

memory动态内存管理学习之weak_ptr

此头文件是动态内存管理库的一部分。std::weak_ptr 是一种智能指针,它持有对被 std::shared_ptr 管理的对象的非拥有性(“弱”)引用。在访问所引用的对象前必须先转换为 std::shared_ptr。std::weak_ptr 用来表达临时所有权的概念&#xff1a…

快速应用开发(RAD):加速软件开发的关键方法

目录 前言1. 快速应用开发的概念1.1 什么是快速应用开发?1.2 RAD与传统开发方法的对比 2. 快速应用开发的实施步骤2.1 需求分析与规划2.2 快速原型开发2.3 用户评估与反馈2.4 迭代开发与改进2.5 最终交付与维护 3. 快速应用开发的优点与应用场景3.1 优点3.2 应用场景…

24级中国科学技术大学843信号与系统考研分数线,中科大843初复试科目,参考书,大纲,真题,苏医工生医电子信息与通信工程。

(上岸难度:★★★★☆,考试大纲、真题、经验帖等考研资讯和资源加群960507167/博睿泽电子信息通信考研咨询:34342183) 一、专业目录及考情分析 说明: ①复试成绩:满分100分。上机满分50分,面试满分150分,复试成绩(上机…

Llama 3 模型微调的步骤

环境准备 操作系统:Ubuntu 22.04.5 LTS Anaconda3:Miniconda3-latest-Linux-x86_64 GPU: NVIDIA GeForce RTX 4090 24GStep 1. 准备conda环境 创建一个新的conda环境: conda create --name llama_factory python3.11激活刚刚创…

线性代数--行列式1

本篇来自对线性代数第一篇的行列式的一个总结。 主要是行列式中有些关键点和注意事项,便于之后的考研复习使用。 首先,对于普通的二阶和三阶行列式,我们可以直接对其进行拆开,展开。 而对于n阶行列式 其行列式的值等于它的任意…

系统运维面试题总结(网络基础类)

系统运维面试题总结(网络基础类) 网络基础类第七层:应用层第六层:表示层第五层:会话层第四层:传输层第三层:网络层第二层:数据链路层第一层:物理层 类似面试题1、TCP/IP四…

Django 配置静态文件

1,DebugTrue 调试模式 Test/Test/settings.py DEBUG True...STATICFILES_DIRS [os.path.join(BASE_DIR, static),] STATIC_URL /static/ 1.1 创建静态文件 Test/static/6/images/Sni1.png 1.2 添加视图函数 Test/app6/views.py from django.shortcuts impor…

使用Java实现通用树形结构转换工具类:深入解析TreeUtil和TreeNode接口

文章目录 一、TreeNode接口设计二、TreeUtil工具类设计三、示例:实现TreeNode接口的节点类四、示例:使用TreeUtil构建树形结构五、总结 🎉欢迎来到Java学习路线专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1…

落石滑坡监测报警系统:创新保障高速公路安全

​ ​​在现代交通建设中,高速公路的安全性和稳定性至关重要。特别是易发生落石区域,如何有效预防和应对落石滑坡带来的事故成为了一项关键性挑战。为此,落石滑坡监测报警系统应运而生,它通过先进的技术手段,为高速…