【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

news2024/11/16 9:46:21

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包

pip install torchsummary

import torch
import torch.nn as nn
from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构

# 创建神经网络模型类
class Model(nn.Module):
    # 初始化模型的构造函数
    def __init__(self):
        super().__init__()  # 调用父类 nn.Module 的初始化方法
        # 定义第一个全连接层(线性层),3个输入特征,3个输出特征
        self.linear1 = nn.Linear(3, 3)  
        # 使用 Xavier 正态分布初始化第一个全连接层的权重
        nn.init.xavier_normal_(self.linear1.weight)
        
        # 定义第二个全连接层,输入 3 个特征,输出 2 个特征
        self.linear2 = nn.Linear(3, 2)
        # 使用 Kaiming 正态分布初始化第二个全连接层的权重,适合 ReLU 激活函数
        nn.init.kaiming_normal_(self.linear2.weight)
        
        # 定义输出层,输入 2 个特征,输出 2 个特征
        self.out = nn.Linear(2, 2)
        
    # 定义前向传播过程 (forward 函数会自动执行,类似于模型的"推理"过程)
    def forward(self, x):
        # 第一个全连接层运算
        x = self.linear1(x)
        # 使用 Sigmoid 激活函数
        x = torch.sigmoid(x)
        
        # 第二个全连接层运算
        x = self.linear2(x)
        # 使用 ReLU 激活函数
        x = torch.relu(x)
        
        # 输出层运算
        x = self.out(x)
        # 使用 Softmax 激活函数,将输出转化为概率分布
        # dim=-1 表示在最后一个维度(通常是输出的类别维度)上做 softmax 归一化
        x = torch.softmax(x, dim=-1)
        return x
    
if __name__ == '__main__':
    # 实例化神经网络模型
    my_model = Model()
    
    # 随机生成一个形状为 (5, 3) 的输入数据,表示 5 个样本,每个样本有 3 个特征
    my_data = torch.randn(5, 3)
    print("mydata shape", my_data.shape)
    
    # 通过模型进行前向传播,输出模型的预测结果
    output = my_model(my_data)
    print("output shape", output.shape)
    
    # 计算并显示模型的参数总量以及模型结构
    summary(my_model, input_size=(3,), batch_size=5)
    
    # 查看模型中所有的参数,包括权重和偏置项(bias)
    print("-----查看模型参数w 和 b  -----")
    for name, parameter in my_model.named_parameters():
        print(name, parameter)

mydata shape torch.Size([5, 3])
output shape torch.Size([5, 2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                     [5, 3]              12
            Linear-2                     [5, 2]               8
            Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
-----查看模型参数w 和 b  -----
linear1.weight Parameter containing:
tensor([[ 0.4777, -0.2076,  0.4900],
        [-0.1776,  0.4441,  0.6924],
        [-0.5449,  1.6153,  0.0243]], requires_grad=True)
linear1.bias Parameter containing:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.0510, -1.2731, -0.7253],
        [-0.6112,  0.1189, -0.4903]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.5391, 0.2552], requires_grad=True)
out.weight Parameter containing:
tensor([[-0.3271, -0.3483],
        [-0.0619, -0.0680]], requires_grad=True)
out.bias Parameter containing:
tensor([-0.5508,  0.5895], requires_grad=True)
 

 代码输出结果解读

​​​​​​​

这个代码的输出展示了两部分内容:

  1. 数据维度和模型输出维度

    • mydata shape torch.Size([5, 3])

    • output shape torch.Size([5, 2])

  2. 模型的结构、参数数量和每一层的权重与偏置

    • 模型的层结构、每一层的输出形状,以及每一层的参数数量。

    • 每层的权重(weight)和偏置(bias)的具体数值。

让我们详细分析每一部分的输出。

1. 输入数据和输出数据的形状

mydata shape torch.Size([5, 3])

这部分的输出说明:

  • 输入数据的形状(5, 3),表示有 5 个样本,每个样本有 3 个特征。这与模型定义时的输入层 nn.Linear(3, 3) 是一致的,输入层期望接收 3 个特征。

output shape torch.Size([5, 2])

这部分的输出说明:

  • 模型输出的形状 (5, 2),表示 5 个样本的输出,每个样本的输出有 2 个值。由于模型的输出层定义为 nn.Linear(2, 2),它接收 2 个输入特征并输出 2 个值,符合预期。

2. 模型结构和参数

模型结构和参数信息是通过 summary() 函数生成的,它列出了每一层的名称、输出形状和参数数量。

详细输出解释:
----------------------------------------------------------------
      Layer (type)               Output Shape         Param #
================================================================
          Linear-1                     [5, 3]             12
          Linear-2                     [5, 2]               8
          Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
线性层 1(Linear-1
  • 层的类型Linear,这是一个全连接层,定义为 nn.Linear(3, 3)

  • 输出形状[5, 3],表示输入了 5 个样本,每个样本有 3 个特征,经过该层的输出仍然是 5 个样本,每个样本有 3 个特征。

  • 参数数量:12,其中 9 个是权重参数(3 x 3 的权重矩阵),另外 3 个是偏置项。

线性层 2(Linear-2
  • 层的类型Linear,定义为 nn.Linear(3, 2),将 3 个输入特征映射到 2 个输出特征。

  • 输出形状[5, 2],表示输入了 5 个样本,每个样本有 2 个输出特征。

  • 参数数量:8,其中 6 个是权重参数(3 x 2 的权重矩阵),另外 2 个是偏置项。

输出层(Linear-3
  • 层的类型Linear,定义为 nn.Linear(2, 2),接收 2 个输入特征,输出 2 个特征。

  • 输出形状[5, 2],表示 5 个样本,每个样本的输出为 2 个特征。

  • 参数数量:6,其中 4 个是权重参数(2 x 2 的权重矩阵),另外 2 个是偏置项。

参数统计
  • 总参数数量:26,模型中所有可训练参数(包括权重和偏置)的总数量。

  • 可训练参数:26,模型中所有参与训练的参数。这里所有的参数都是可训练的(requires_grad=True),没有非可训练的参数。

  • 非可训练参数:0,说明模型中没有被设置为不可训练的参数。

3. 查看每一层的权重和偏置

这一部分输出列出了每一层的具体参数(权重和偏置)的值。

linear1.weight:
tensor([[ 0.4777, -0.2076, 0.4900],
      [-0.1776, 0.4441, 0.6924],
      [-0.5449, 1.6153, 0.0243]], requires_grad=True)

这是 linear1 层的权重矩阵,形状是 (3, 3)。由于 linear1nn.Linear(3, 3),它的权重矩阵也是 3 行 3 列。权重参数是使用 Xavier 初始化(nn.init.xavier_normal_)初始化的。

linear1.bias:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)

这是 linear1 层的偏置项,形状是 (3,),因为每个输出特征对应一个偏置值。

linear2.weight:
tensor([[-0.0510, -1.2731, -0.7253],
      [-0.6112, 0.1189, -0.4903]], requires_grad=True)

这是 linear2 层的权重矩阵,形状是 (2, 3),因为 linear2nn.Linear(3, 2),需要 3 个输入特征映射到 2 个输出特征。权重是使用 Kaiming 初始化nn.init.kaiming_normal_初始化的。

linear2.bias:
tensor([0.5391, 0.2552], requires_grad=True)

这是 linear2 层的偏置项,形状是 (2,),因为每个输出特征对应一个偏置值。

out.weight:
tensor([[-0.3271, -0.3483],
      [-0.0619, -0.0680]], requires_grad=True)

这是输出层 out 的权重矩阵,形状是 (2, 2),因为 outnn.Linear(2, 2),接收 2 个输入特征并输出 2 个特征。

out.bias:
tensor([-0.5508, 0.5895], requires_grad=True)

这是输出层 out 的偏置项,形状是 (2,)

总结

  • 这段代码展示了一个简单的神经网络模型,包含 3 个全连接层(线性层),每层的输入输出特征数量逐步缩小。

  • 我们通过 summary() 查看了模型的整体结构,展示了每一层的输出形状和参数数量,总共有 26 个参数。

  • 每一层的权重和偏置参数值被输出,展示了它们是如何被初始化的(通过 Xavier 和 Kaiming 初始化)。

  • 该模型的前向传播通过激活函数(sigmoidReLU)以及 softmax 将输出转化为概率分布。

​​​​​​​​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2160047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ComfyUI】控制光照节点——ComfyUI-IC-Light-Native

原始代码(非comfyui):https://github.com/lllyasviel/IC-Light comfyui实现1(600星):https://github.com/kijai/ComfyUI-IC-Light comfyui实现2(500星):https://github.c…

【QT】QSS基础

欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:QT 目录 👉🏻基本语法👉🏻从⽂件加载样式表👉🏻选择器伪类选择器 👉&…

动手学深度学习9.1. 门控循环单元(GRU)-笔记练习(PyTorch)

本节课程地址:门控循环单元(GRU)_哔哩哔哩_bilibili 本节教材地址:9.1. 门控循环单元(GRU) — 动手学深度学习 2.0.0 documentation (d2l.ai) 本节开源代码:...>d2l-zh>pytorch>chap…

K8S服务发布

一 、服务发布方式对比 二者主要区别在于: 1. 部署复杂性:传统的服务发布方式通常涉及手动配置 和管理服务器、网络设置、负载均衡等,过程相对复 杂且容易出错。相比之下,Kubernetes服务发布方式 通过使用容器编排和自动化部署工…

高灵敏度电容式触摸IC在弹簧触控按键中的应用

电容式触摸IC-弹簧触控按键-是通过检测人体与传感器之间的电容变化来实现触摸控制。这种技术具有高灵敏度、稳定性好、防水性强等优点,广泛应用于家用电器、消费电子、工业控制等领域。 弹簧触控按键的特点: 1. 高灵敏度:即使隔着绝缘材料&a…

Java语言的Springboot框架+云快充协议1.5+充电桩系统+新能源汽车充电桩系统源码

介绍 云快充协议云快充1.5协议云快充协议开源代码云快充底层协议云快充桩直连桩直连协议充电桩系统桩直连协议 有需者咨询,非诚勿扰; 软件架构 1、提供云快充底层桩直连协议,版本为云快充1.5,对于没有对接过充电桩系统的开发者…

[vulnhub] Jarbas-Jenkins

靶机链接 https://www.vulnhub.com/entry/jarbas-1,232/ 主机发现端口扫描 扫描网段存活主机,因为主机是我最后添加的,所以靶机地址是135的 nmap -sP 192.168.75.0/24 // Starting Nmap 7.93 ( https://nmap.org ) at 2024-09-21 14:03 CST Nmap scan…

求职Leetcode题目(11)

1.最长连续序列 解题思路: 方法一: • 首先对数组进行排序,这样我们可以直接比较相邻的元素是否连续。• 使用一个变量 cur_cnt 来记录当前的连续序列长度。• 遍历排序后的数组: 如果当前元素与前一个元素相等,则跳过&#xf…

文档矫正算法:DocTr++

文档弯曲矫正(Document Image Rectification)的主要作用是在图像处理领域中,对由于拍摄、扫描或打印过程中产生的弯曲、扭曲文档进行校正,使其恢复为平整、易读的形态。 一. 论文和代码 论文地址:https://arxiv.org/…

AI辅助编码工具如何影响着程序员开发群体

AI辅助编码工具的出现对程序员开发群体产生了深远的影响,有一些初步基础的程序员,可以借助AI工具的加持,生产效率大大提升,达到中高级程序员的水平。 这些影响可以从多个角度来分析: 提高开发效率: AI工具…

跳蚤市场小程序|基于微信小程序的跳蚤市场(源码+数据库+文档)

跳蚤市场小程序目录 基于微信小程序的饮品点单系统的设计与实现 一、前言 二、系统功能设计 三、系统实现 管理员功能实现 商品信息管理 商品订单管理 论坛管理 用户管理 5.1.5 新闻信息管理 用户功能实现 四、数据库设计 1、实体ER图 2、具体的表设计如下所示&a…

毕业设计选题:基于ssm+vue+uniapp的英语学习激励系统小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

STM32引脚输入

文章目录 前言一、看原理图二、开始编程1.开启时钟2.配置GPIOA.0 上拉输入3.读取 GPIOA.0 引脚 GPIOA_IDR 0位上是1(按键松开),输入就是高电平,否则就是低电平(按键按下) 三、完整程序四 测试效果总结 前言…

Spring MVC的应用

目录 1、创建项目与maven坐标配置 2、核心配置 3、启动项目测试 4、不同请求参数在controller的配置 4.1 servlet API 4.2 简单类型 4.3 pojo类型 4.4 日期类型 4.5 restful风格4种操作类型 4.5.1 GET:获取资源 4.5.2 POST:新建资源 4.5.3 P…

【Godot4.3】自定义数列类NumList

概述 数列是一种特殊数组。之前写过等比、等差数列、斐波那契等数列的求取函数。今天就汇总到一起,并添加其他的一些数列,比如平方数、立方数、三角形数等。 这里我首先采用以前比较喜欢的静态函数库的写法,然后在其基础上改进为基于类继承…

基于飞腾平台的OpenCV的编译与安装

【写在前面】 飞腾开发者平台是基于飞腾自身强大的技术基础和开放能力,聚合行业内优秀资源而打造的。该平台覆盖了操作系统、算法、数据库、安全、平台工具、虚拟化、存储、网络、固件等多个前沿技术领域,包含了应用使能套件、软件仓库、软件支持、软件适…

ChatGPT 推出“Auto”自动模式:智能匹配你的需求

OpenAI 最近为 ChatGPT 带来了一项新功能——“Auto”自动模式,这一更新让所有用户无论使用哪种设备都能享受到更加个性化的体验。简单来说,当你选择 Auto 模式后,ChatGPT 会根据你输入的提示词复杂程度,自动为你挑选最适合的AI模…

解密 Python 的 staticmethod 函数:静态方法的全面解析!

更多Python学习内容:ipengtao.com 在 Python 中,staticmethod 函数是一种装饰器,用于将函数转换为静态方法。静态方法与实例方法和类方法不同,它们不需要类实例作为第一个参数,也不需要类作为第一个参数,因…

只用几行代码,不依赖任何框架?SMTFlow 轻松实现前端流程图

只用几行代码,不依赖任何框架?SMTFlow 轻松实现前端流程图! 在前端开发中,如果你需要一个简单好用的流程图设计工具,SMTFlow 绝对是你的不二之选!本文将介绍 SMTFlow 的核心功能、特点以及如何快速上手。 工…

C++中set和map的使用

1.关联式容器 序列式容器里存储的是元素本身&#xff0c;如vector、list、deque 关联式容器即&#xff0c;容器中存储<key&#xff0c;value>的键值对&#xff0c;树型结 构的关联式容器主要有四种&#xff1a;map、set、multimap、multiset。他们都使用平衡搜索树(即红…