昇思25天学习打卡营第5天|网络构建

news2025/1/13 10:23:43

 一、简介:

神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类(这个类和pytorch中的modul类是一样的作用),也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

二、环境准备:

import mindspore
import time
from mindspore import nn, ops

没有下载mindspore的宝子,还是回看我的昇思25天学习打卡营第1天|快速入门-CSDN博客,先下载好再进行下面的操作。

三、神经网络搭建:

1、定义模型类:

我们首先要继承nn.Cell类,并再__init__方法中进行子Cell的实例化和管理,并再construct方法(和pytorch中的forward方法一致)中实现前向计算:

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

# 实例化并打印
model = Network()
print(model)
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 ①self.flatten = nn.Flatten():创建一个Flatten层,并将其作为类的属性。Flatten层的作用是将输入的数据“压平”,即不管输入数据的原始形状如何,输出都将是沿着特定维度的连续数组。

② self.dense_relu_sequential = nn.SequentialCell(...):创建一个SequentialCell,它是一种特殊的Cell,可以顺序地执行其中包含的多个层。这个SequentialCell包含了三个全连接层(Dense),每个全连接层后面跟着一个ReLU激活函数层,除了最后一个全连接层:

  • 第一个nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"):这是一个全连接层,它接受28*28=784个输入,并产生512个输出。权重(weight_init)和偏置(bias_init)分别使用正态分布和零值进行初始化。

  • nn.ReLU():ReLU激活函数,其数学表达式为f(x) = max(0, x),即负值输出为零,正值保持不变。

  • 接下来的两个nn.Dense与对应的nn.ReLU层与第一个类似,它们分别接收512个输入并再次输出512个值,以及最终输出10个值,这可能对应于10个类别。

我们构造一个数据,并使用softmax预测其概率:

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
print(logits)

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

2、模型层详解:

(1)nn.Flatten:

 nn.Flantten方法用于将输入数据“压平”,以便后续处理:

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(2)nn.Dense:

 nn.Dense层作为全连接层,用于对输入的数据进行线性变换和处理:

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(3)nn.Relu:

nn.Relu是本次实验中使用的激活函数,用于对神经网络的权重进行处理,以缓解欠拟合和过拟合的发生,常见的激活函数处了Relu,还有:Sigmoid, Tanh等:

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 (4)nn.SequentialCell:

nn.SequentialCell和pytorch中的nn.Sequential的作用一样,用于存放dense全连接层和激活函数层的组合,以方便在前向计算中使用:

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(5)nn.Softmax:

 nn.softmax方法将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)
print(pred_probab)
# argmax函数返回指定维度上最大值的索引
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

3、模型参数:

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")
    
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1854087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LVGL8.3动画图像(太空人)

LVGL8.3 动画图像 1. 动画图像本质 我们知道电影属于视频,而电影的本质是将一系列动作的静态图像进行快速切换而呈现出动画的形式,也就是说动画本质是一系列照片。所以 lvgl 依照这样的思想而定义了动画图像,所以在 lvgl 中动画图像类似于普…

【学习笔记】Mybatis-Plus(三):MP中Wrapper的使用

Wrapper简介 注意: 查询用QueryWrapper和LambdaQueryWrapper来封装 updateWrapper和LambdaUPdateWrapper不但能封装查询还能更改要更新的对象。 QueryWrapper的使用 QueryWrapper中的很多条件限定都是见名知其意的。下表列出来几个常用的: 1.多条件进行…

【八】【QT开发应用】QTcreate项目打包成.exe文件或.apk文件,EnigmaVirtualBox软件下载,虚拟网站代打开QT应用

EnigmaVirtualBox下载 Enigma Virtual Box QTcreate项目打包成.exe可执行文件 找到自己写好的项目的.exe文件 将这个文件复制到一个新的文件夹里面 在这个新的文件夹里面打开cmd,这样可以使得cmd直接进入到该文件夹 打包.exe命令行 输入下面的命令行 windeployqt game…

EndNote 21 for Mac v21.3 文献管理软件安装

Mac分享吧 文章目录 效果一、下载软件二、开始安装1、双击运行安装EndNote212、升级 三、运行1、打开软件,测试 安装完成!!! 效果 一、下载软件 下载软件 链接:http://www.macfxb.cn 二、开始安装 1、双击运行安装End…

【目标检测】DAB-DETR

一、引言 论文: DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR 作者: IDEA 代码: DAB-DETR 注意: 该算法是对DETR的改进,在学习该算法前,建议掌握多头注意力、Sinusoidal位置编码、DETR等相…

一款基于WordPress开发的高颜值的自适应主题Puock

主题特性 支持白天与暗黑模式 全局无刷新加载 支持博客与CMS布局 内置WP优化策略 一键全站变灰 网页压缩成一行 后台防恶意登录 内置出色的SEO功能 评论Ajax加载 文章点赞、打赏 支持Twemoji集成 支持QQ登录 丰富的广告位 丰富的小工具 自动百度链接提交 众多页面模板 支持评论…

富文本编辑器CKEditor

介绍 富文本编辑器不同于文本编辑器,它提供类似于 Microsoft Word 的编辑功能 在Django中,有可以现成的富文本三方模块django-ckeditor,具体安排方式: pip install django-ckeditor==6.5.1官网:Django CKEditor — Django CKEditor 6.7.0 documentation 使用方式 创建项…

torchinfo这个包中的summary真的很好用

1.安装直接使用 pip 进行安装即可: pip install torchinfo 2.导入该模块 from torchinfo import summary 3.使用模块 summary(model)#这里的model是你自己的model,可以添加参数进去 4.效果图: 第一个图片是直接打印model吗,…

「动态规划」如何求环绕字符串中唯一的子字符串个数?

467. 环绕字符串中唯一的子字符串https://leetcode.cn/problems/unique-substrings-in-wraparound-string/description/ 定义字符串base为一个"abcdefghijklmnopqrstuvwxyz"无限环绕的字符串,所以base看起来是这样的:"...zabcdefghijklm…

华硕笔记本重装系统详细操作,图文教程体验Win11如何重装系统

随着科技的不断发展,电脑操作系统的步骤也在不断更新迭代。对于华硕笔记本用户来说,升级到Windows 11操作系统可以带来更好的使用体验。本文将通过图文教程的形式,详细介绍华硕笔记本重装Windows 11系统的操作步骤,帮助用户顺利完…

2-14 基于matlab的GA优化算法优化车间调度问题

基于matlab的GA优化算法优化车间调度问题。n个工作在m个台机器上加工。已知每个工作中工序加工顺序、各工序的加工时间以及每个工件所包含的工序,在满足约束条件的前提下,目的是确定机器上各工件顺序,以保证某项性能指标最优。程序功能说明&a…

Selenium进行Web自动化测试

Selenium进行Web自动化测试 SeleniumPython实现Web自动化测试一、环境配置 SeleniumPython实现Web自动化测试 一、环境配置 环境基于win10(X64) 安装Python;安装PyCham安装chomedriver chomedriver下载地址 可以查看本地chrome软件版本下载…

cesium 添加 Echarts 饼图

cesium 添加 Echarts 饼图 1、实现思路 1、首先创建echarts饼图,拿到创建好的canvas 2、用echarts里面生成的canvas添加到cesium billboard中 2、示例代码 <!DOCTYPE html> <html lang="en"><head><

【database2】redis:优化/备份/订阅

文章目录 1.redis安装&#xff1a;加载.conf2.操作&#xff1a;set/get&#xff0c;push/pop&#xff0c;add/rem3.Jedis&#xff1a;java程序连接redis&#xff0c;拿到jedis4.案例_好友列表&#xff1a;json om.4.1 前端&#xff1a;index.html4.2 web&#xff1a;FriendSer…

oracle发送http请求

UTL_HTTP包让SQL和PLSQL能够调用超文本传输协议&#xff08;HTTP&#xff09;&#xff0c;也就是说可以使用它在Internet上访问数据。 当包用HTTPS从Web site获取数据时&#xff0c;要使用Oracle Wallet&#xff0c;它是由Oracle Wallet Manager或者orapki utility创建。非HTT…

双jdk切换

现在因为业务需求单一jdk8已经不满足日常需求了,以我为例之前用的jdk8,但是最新的一个项目用的是17版本的,没招了就下载配置的一套,需要手动切换用哪个版本的步骤如下 jdk8就自己安装配置吧,这只说在有8的版本上在配置17 1.下载一个17win的包(不下载exe) Java Downloads | O…

git 初基本使用-----------笔记

Git命令 下载git 打开Git官网&#xff08;git-scm.com&#xff09;&#xff0c;根据自己电脑的操作系统选择相应的Git版本&#xff0c;点击“Download”。 基本的git命令使用 可以在项目文件下右击“Git Bash Here” &#xff0c;也可以命令终端下cd到指定目录执行初始化命令…

pytorch实现的面部表情识别

一、绪论 1.1 研究背景 面部表情识别 (Facial Expression Recognition ) 在日常工作和生活中&#xff0c;人们情感的表达方式主要有&#xff1a;语言、声音、肢体行为&#xff08;如手势&#xff09;、以及面部表情等。在这些行为方式中&#xff0c;面部表情所携带的表达人类…

vue-cli搭建

一、vue-cli是什么&#xff1f; vue-cli 官方提供的一个脚手架&#xff0c;用于快速生成一个 vue 的项目模板&#xff1b;预先定义 好的目录结构及基础代码&#xff0c;就好比咱们在创建 Maven 项目时可以选择创建一个 骨架项目&#xff0c;这个骨架项目就是脚手架&#xff0c;…