神经网络基础-神经网络搭建和参数计算

news2024/12/18 14:00:08

文章目录

    • 1.构建神经网络
    • 2. 神经网络的优缺点

1.构建神经网络

在 pytorch 中定义深度神经网络其实就是层堆叠的过程,继承自nn.Module,实现两个方法:

  • __init__方法中定义网络中的层结构,主要是全连接层,并进行初始化。
  • forward方法,在实例化模型的时候,底层会自动调用该函数。该函数中可以定义学习率,为初始化定义的layer传入数据等。

我们来构建如下图所示的神经网络模型:
在这里插入图片描述

编码设计如下:

  1. 第1个隐藏层:权重初始化采用标准化的xavier初始化 激活函数使用sigmoid。
  2. 第2个隐藏层:权重初始化采用标准化的He初始化 激活函数采用relu。
  3. out输出层线性层 假若二分类,采用softmax做数据归一化。
# 创建神经网络
import torch
import torch.nn as nn
# pip install torchsummary
from torchsummary import summary # 计算模型参数,查看模型结构 pip install torchsummary
# 创建神经网络模型类
class Model(nn.Module):
    # 初始化属性值
    def __init__(self):
        # 调用父类的初始化属性值
        super(Model, self).__init__()
        # 创建第一个隐藏层模型,3个输入特征,3个输出特征
        self.linear1 = nn.Linear(3, 3)
        # 初始化权重 xavier 均匀分布初始化
        nn.init.xavier_uniform_(self.linear1.weight)
        # 创建第二个隐藏层,3个输入特征(上一层的输出特征),2个输出特征
        self.linear2 = nn.Linear(3, 2)
        # 初始化权重 kaiming 正太分布初始化
        nn.init.kaiming_normal_(self.linear2.weight)
        # 创建输出层模型
        self.out = nn.Linear(2, 2)
    # 创建向前传播方法,自动执行 forward()方法
    def forward(self, x):
        # 数据经过第一个线性层
        x = self.linear1(x)
        # 使用 sigmoid 激活函数
        x = torch.sigmoid(x)
        # 数据经过第二个线性层
        x = self.linear2(x)
        # 使用 relu 激活函数
        x = torch.relu(x)
        # 数据经过输出层
        x = self.out(x)
        # 使用 softmax 激活函数
        # dim=-1:每一维度行数据相机为1
        x = torch.softmax(x, dim=-1)
        return x

if __name__ == '__main__':
    # 实例化model对象
    model = Model()
    # 随机产生数据
    data = torch.randn(5,3)
    print('data.shape',data.shape)
    # 数据经过神经网络模型训练
    out = model(data)
    print('out.shape',out.shape)
    # 计算模型参数
    # 计算每层每个神经元的 w 和 b 个数总和
    summary(model,input_size=(3,),batch_size=5)
    # 查看模型参数
    print("======查看模型参数w和b======")
    for name, param in model.named_parameters():
        print(name, param)
  • 神经网络的输入数据是为[batch_size, in_features]的张量经过网络处理后获取了[batch_size, out_features]的输出张量。

  • 在上述例子中,batch_size=5, in_features=3,out_features=2,结果如下所示:

    data.shape torch.Size([5, 3])
    out.shape torch.Size([5, 2])
    

    模型参数输出:

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Linear-1                     [5, 3]              12
                Linear-2                     [5, 2]               8
                Linear-3                     [5, 2]               6
    ================================================================
    Total params: 26
    Trainable params: 26
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ----------------------------------------------------------------
    ======查看模型参数w和b======
    linear1.weight Parameter containing:
    tensor([[ 0.3857,  0.4809, -0.0346],
            [ 0.3645,  0.2803, -0.6291],
            [ 0.1999, -0.6617,  0.7724]], requires_grad=True)
    linear1.bias Parameter containing:
    tensor([0.3084, 0.5636, 0.4501], requires_grad=True)
    linear2.weight Parameter containing:
    tensor([[ 0.1063,  0.7494,  0.4311],
            [-1.4152,  0.3396, -0.8590]], requires_grad=True)
    linear2.bias Parameter containing:
    tensor([-0.3771,  0.2937], requires_grad=True)
    out.weight Parameter containing:
    tensor([[-0.6012,  0.4727],
            [-0.2953, -0.5854]], requires_grad=True)
    out.bias Parameter containing:
    tensor([-0.3271,  0.4940], requires_grad=True)
    

模型参数的计算:

  1. 以第一个隐层为例:该隐层有3个神经元,每个神经元的参数为:4个(w1,w2,w3,b1),所以一共用3x4=12个参数。
  2. 输入数据和网络权重是两个不同的事儿!对于初学者理解这一点十分重要,要分得清。
    在这里插入图片描述

2. 神经网络的优缺点

  1. 优点
    ➢ 精度高,性能优于其他的机器学习算法,甚至在某些领域超过了人类。
    ➢ 可以近似任意的非线性函数。
    ➢ 近年来在学界和业界受到了热捧,有大量的框架和库可供调。
  2. 缺点
    ➢ 黑箱,很难解释模型是怎么工作的。
    ➢ 训练时间长,需要大量的计算资源。
    ➢ 网络结构复杂,需要调整超参数。
    ➢ 部分数据集上表现不佳,容易发生过拟合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2261624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web网页前后端交互方式

参考该文&#xff0c; 一、前端通过表单<form>向后端发送数据 前端是通过html中的<form>表单&#xff0c;设置method属性定义发送表单数据的方式是get还是post。 如使用get方式&#xff0c;则提交的数据会在url中显示&#xff1b;如使用post方式&#xff0c;提交…

Mac配置 Node镜像源的时候报错解决办法

在Mac电脑中配置国内镜像源的时候报错,提示权限问题,无法写入配置文件。本文提供解决方法,青测有效。 一、原因分析 遇到的错误是由于 .npm 目录下的文件被 root 用户所拥有,导致当前用户无法写入相关配置文件。 二、解决办法 在终端输入以下命令,输入管理员密码即可。 su…

Linux实操篇-远程登录/Vim/开机重启

目录 传送门前言一、远程登录1、概念2、ifconfig3、实战3.1、SSH&#xff08;Secure Shell&#xff09;3.2、VNC&#xff08;Virtual Network Computing&#xff09;3.3、RDP&#xff08;Remote Desktop Protocol&#xff09;3.4、Telnet&#xff08;不推荐&#xff09;3.5、FT…

【C/C++进阶】CMake学习笔记

本篇文章包含的内容 一、CMake简介二、使用CMake构建工程2.1 一个最简单的CMake脚本2.2 使用变量和宏2.3 文件搜索 三、使用CMake制作和使用库文件3.1 静态库和动态库3.2 字符串操作3.3 CMake制作库文件3.4 CMake使用库文件3.4.1 使用link_libraries链接3.4.2 使用target_link_…

JS 生成防篡改水印

网页中有水印的需求&#xff0c;今天我们实现手写一个防篡改水印&#xff0c;先看下效果图&#xff1a; 一、创建class函数 传递一个dom为水印包裹器&#xff0c;有一些监听防篡改的observer&#xff0c;然后实例化的时候创建水印&#xff0c;执行create()方法 class WaterMa…

概率论得学习和整理26:EXCEL 关于plot 折线图--频度折线图的一些细节

目录 0 折线图有很多 1 频度折线图 1.1 直接用原始数据做的频度折线图 2 将原始数据生成数据透视表 3 这样可以做出了&#xff0c;频度plot 4 做按某字段汇总&#xff0c;成为累计plot分布 5 修改上面显示效果&#xff0c;做成百分比累计plot频度分布 0 折线图有很多 这…

实现echart大屏动画效果及全屏布局错乱解决方式

如何实现echarts动画效果?如何实现表格或多个垂直布局的柱状图自动滚动效果?如何解决tooltip位置超出屏幕问题,如何解决legend文字过长,布局错乱问题?如何处理饼图的中心图片永远居中? 本文将主要解决以上问题,如有错漏,请指正. 一、大屏动画效果 这里的动画效果主要指&…

pytest入门九:feature

fixture是pytest特有的功能&#xff0c;用以在测试执行前和执行后进行必要的准备和清理工作。使用pytest.fixture标识&#xff0c;定义在函数前面。在你编写测试函数的时候&#xff0c;你可以将此函数名称做为传入参数&#xff0c;pytest将会以依赖注入方式&#xff0c;将该函数…

C# 中的闭包

文章目录 前言一、闭包的基本概念二、匿名函数中的闭包1、定义和使用匿名函数2、匿名函数捕获外部变量3、闭包的生命周期 三、Lambda 表达式中的闭包1、定义和使用 Lambda 表达式2、Lambda 表达式捕获外部变量3、闭包的作用域 四、闭包的应用场景1、事件处理2、异步编程3、迭代…

ChatGPT客户端安装教程(附下载链接)

用惯了各类AI的我们发现每天打开网页还挺不习惯和麻烦&#xff0c;突然发现客户端上架了&#xff0c;懂摸鱼的人都知道这里面的道行有多深&#xff0c;话不多说&#xff0c;开整&#xff01; 以下是ChatGPT客户端的详细安装教程&#xff0c;适用于Windows和Mac系统&#xff1a…

GRE over IPSec 如何应用?如何在ensp上配置GRE over IPSec 实验?

GRE over IPSec应用场景 IPSec VPN本端设备无法感知对端有几个设备 &#xff0c;本端共用一个IPSec SA 。报文封装中没有对端设备的下一跳 &#xff0c;所以无法传输组播、广播和非IP报文 &#xff0c;比如OSPF协议 &#xff0c;导致分支与总部的内部网络之间无法使用OSPF路由…

概率论得学习和整理29: 用EXCEL 描述二项分布

目录 1 关于二项分布的基本内容 2 二项分布的概率 2.1 核心要素 2.2 成功K次的概率&#xff0c;二项分布公式 2.3 期望和方差 2.4 具体试验 2.5 概率质量函数pmf 和cdf 3 二项分布的pmf图的改进 3.1 改进折线图 3.2 如何生成这种竖线图呢 4 不同的二项分布 4.1 p0.…

leetcode 面试经典 150 题:三数之和

链接三数之和题序号11类型数组解题方法排序双指针法难度中等 题目 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c; 同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三…

【Linux】Nginx一个域名https一个地址配置多个项目【项目实战】

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

【线性代数】理解矩阵乘法的意义(点乘)

刚接触线性代数时&#xff0c;很不理解矩阵乘法的计算规则&#xff0c;为什么规则定义的看起来那么有规律却又莫名其妙&#xff0c;现在参考了一些资料&#xff0c;回过头重新总结下个人对矩阵乘法的理解&#xff08;严格来说是点乘&#xff09;。 理解矩阵和矩阵的乘法&#x…

HTML、CSS表格的斜表头样式设置title 画对角线

我里面有用到layui框架的影响&#xff0c;实际根据你自己的框架来小调下就可以 效果如下 上代码 <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-wi…

29. Three.js案例-自定义平面图形

29. Three.js案例-自定义平面图形 实现效果 知识点 WebGLRenderer WebGLRenderer 是 Three.js 中用于渲染 3D 场景的核心类。它利用 WebGL 技术在浏览器中渲染 3D 图形。 构造器 THREE.WebGLRenderer(parameters : object) 参数类型描述parametersobject可选参数对象&…

一条线上的点

给你一个数组 points &#xff0c;其中 points[i] [xi, yi] 表示 X-Y 平面上的一个点。求最多有多少个点在同一条直线上。 提示&#xff1a; 1 < points.length < 300points[i].length 2-104 < xi, yi < 104points 中的所有点 互不相同 解析&#xff1a;使用斜…

WebRTC服务质量(05)- 重传机制(02) NACK判断丢包

WebRTC服务质量&#xff08;01&#xff09;- Qos概述 WebRTC服务质量&#xff08;02&#xff09;- RTP协议 WebRTC服务质量&#xff08;03&#xff09;- RTCP协议 WebRTC服务质量&#xff08;04&#xff09;- 重传机制&#xff08;01) RTX NACK概述 WebRTC服务质量&#xff08;…

八股—Java基础(二)

目录 一. 面向对象 1. 面向对象和面向过程的区别&#xff1f; 2. 面向对象三大特性 3. Java语言是如何实现多态的&#xff1f; 4. 重载&#xff08;Overload&#xff09;和重写&#xff08;Override&#xff09;的区别是什么&#xff1f; 5. 重载的方法能否根据返回值类…