通过类定义一个网络

news2024/11/25 11:51:27
import torch
from torch import nn

x = torch.ones(2,10)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)
    def forward(self,x):
        return self.out(x)

 1. 代码解析

  • 如何定义一个类?self 又是什么东西?
  • 类是如何继承基类的特性的?nn.module 是个什么对象?
  • 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有?
  • forward 函数有什么作用?该怎样用?
  • 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?
  • nn.Linear 函数是干什么用的?
  • 输入输出张量的形状大小都是怎么对应的?
  • 模型内部的网络参数是怎么定义的,如何查看?

1.1 如何定义一个类 

class MLPSimple(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)

a_s = MLPSimple()
a_s

 self 是指向类自身的一个指针,可以通过该指针引用类自身的成员,默认这个参数是每个成员函数的首个输入参数,如果没有self参数,那么定义的函数将无法引用类自身的成员

1.2 如何继承基类的特性?nn.module 是个什么对象?

定义类的时候,将需要被继承的类的名称作为参数传入,如 class MLPSimple(nn.Module) 这样就是定义了一个新的类 MLPSimple ,这个新类继承了 nn.Module 的所有特性。 以下展示再创建一个类,继承刚刚创建的新类 MLPSimple_p。

class MLPSimple_p(MLPSimple):
    def __init__(self):
        super().__init__()
a_s_p = MLPSimple_p()
a_s_p

nn.Module是PyTorch中的一个类,继承自torch.nn的基类,用于定义神经网络模型、提供前向传播过程所需的基本功能和方法。 在PyTorch中,神经网络模型通常是由多个层组成的,每个层都是一个nn.Module实例。通过继承nn.Module类并实现自己的forward方法,可以定义自己的神经网络模,。在神经网络的训练和推理过程中,PyTorch会自动调用nn.Module的forward方法来计算输出。

1.3 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有? 

与C++不同,我们自己新定义的python类没有显式的构造函数(python 类有自己的构造函数,该构造函数跟 init 函数一样也是个魔法函数),python类的对init函数的调用,可以被看做是类似于C++类调用构造函数类似的过程,当python中通过类创建对象的时候就会调用init函数对对象进初始化,与c++不同的是如果c++继承了基类,那么构造对象的时候会隐式的调用基类的构造方法,这里python却需要显示的主动调用基类初始化方法super().init()对基类的特性进行初始化。这里的显示调用时必须的,如果漏掉会报错。

1.4 forward 函数有什么作用?该怎样用?

python 的类成员方法中有一个非常特殊的函数叫做 call() 函数,这个函数使得实例化的对象自身可以像一个函数一样被调用,如同样实例对象为 a_s_p ,如果这个对象是在C++中,那么这个对象就单纯的是一个对象而已,要想让这个对象处理一些事情,就必须通过对象去调用它自身的一些方法来实现,如a_s_p.func(),但是在python类中,类定义里面有一个特别的函数叫做call函数,这个函数可以使被实例的对象本身像一个函数一样被直接调用,而在pytorch中这个call函数会默认直接调用创建类的forward函数,forward函数会接受所有传递给call函数的参数,call函数本身也会将forward函数的返回结果直接返回,因此就形成了pytorch中这种可以直接通过对象本身来处理数据的现象。

a_s_p(inputs) 隐含的意思就是 a_s_p.call(inputs) , 而 a_s_p.call(inputs) 本身的定义却类似于以下这样:

class MLPSimple(nn.Module):
    def __call__(self,inputs):
        return self.forward(inputs)
    def forward(self,inputs)
        return outputs

1.5 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?

函数前后有两条下划线的方法叫做python的魔法函数,魔法函数本身是指的到了特定状况下会自动被调用的函数,因为其自适应性像魔法一样神奇所以被称为魔法函数,没有下划线的函数指的是普通函数,像forward函数的名称是pytorch的保留字,默认被call函数调用,但它仍然跟普通函数一样,没什么特别之处。其他的魔法函数还有如下这些:

  • init():类的初始化方法,在创建类的实例时自动调用。
  • new():类的构造函数,当使用类的构造函数创建新的类实例时自动调用。
  • str():返回对象的字符串表示,当调用print()函数输出对象时自动调用。
  • del():在对象被删除时自动调用。
  • call():当对象被作为函数调用时自动调用。
  • len():返回对象的长度,当使用len()函数调用对象时自动调用。
  • eq():比较两个对象是否相等,当使用==运算符比较两个对象时自动调用。
  • hash():返回对象的哈希值,当使用hash()函数调用对象时自动调用。
  • getitem():当使用方括号运算符[]访问对象的元素时自动调用。
  • setitem():当使用方括号运算符[]修改对象的元素时自动调用。的元素时自动调用。[]修改对的元素时自动调用。

1.6 nn.Linear 函数是干什么用的?

对输入向量进行线性变换的一个网络层类,,与之类似的类还有以下几个类: 'Bilinear(双线性变换' 'Identity(占位符) 'LazyLinear'(系数矩阵尺寸在第一次被调用时候自动初始化,不需要主动指定)

class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,device=None, dtype=None) -> None:
        ...

   def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

详细介绍见官方文档:Linear — PyTorch 2.0 documentation

初始化时候指定input_tensor[size_in_0,in_features] \output_tensor[size_out_0,out_features],即指定了线性层的系数矩阵尺寸 weight[in_features,out_features],计算时候 out_tensor = input_tensor * weight = [size_in_0,in_features] * [in_features,out_features] = [size_in_0 ,out_features] 在开头的例子中即为 out_tensor = input_tensor * weight = [2,10] * [10,1] = [2 ,1] 以上计算过程中可以发现,矩阵的最后一个维度是样本的特征维度,比如说线性变换中的自变量个数即为 in_features = 10 , 因变量的个数为 out_features = 1 ,这两个个数即为单个样本的特征维度或者说是特征数。倒数第二个维度是样本的批量大小,像本例中输入样本为2,输出样本自然的对应也应该是2,输入样本的数量不需要单独指定,在传入模型处理的时候,模型会自动去识别处理。

1.7 模型内部的网络参数是怎么定义的,如何查看?

访问权重系数 :print(a.out.weight)

访问偏置系数 :print(a.out.bias)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/953413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高尔夫APP外包开发主要功能

高尔夫小程序可以实现教练预约、场地预地、训练课程、积分系统、社交功能等,通过小程序方便用户,同时也提球场的管理能力。今天和大家分享一些主要功能和注意的问题,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包…

详细说明OSPF常见的LSA

目录 1类LSA (Router LSA)介绍 总结:1类LSA 2类LSA (Network LSA)介绍 总结:2类LSA 3类LSA (Summary LSA)介绍 总结:3类LSA 5类LSA (ase LSA&…

二肽-2——祛除眼部水肿和眼部黑眼圈

简介 眼袋形成的一个重要的原因是水肿, 诱因主要是淋巴循环减弱和毛细血管的通透性增加。 INCI 名称 二肽-2 多肽序列 VW CAS号 24587-37-9 机理 抑制血管紧张素转换酶,增强眼部淋巴循环,促进水分排出 二肽-2是一种二胜肽,带有二种标…

高忆管理:沪指弱势调整跌0.53%,地产板块走弱,光刻机概念拉升

31日早盘,A股两市弱势调整。截至午间收盘,沪指跌0.53%报3120.39点,深成指跌0.55%,创业板指跌0.54%,两市算计成交5291亿元。北向资金净流出36亿元。盘面上,半导体、中成药、黄金等板块走强,地产、…

生成式人工智能能否使数字孪生在能源和公用事业行业成为现实?

推荐:使用 NSDT场景编辑器 快速搭建3D应用场景 克服障碍,优化数字孪生优势 要实现数字孪生的优势,您需要数据和逻辑集成层以及基于角色的演示。如图 1 所示,在任何资产密集型行业(如能源和公用事业)中&…

高忆管理:A股上市券商“中考”成绩放榜,最大黑马是它

A股上市券商2023年半年报发表8月30日晚正式收官。全体上看,43家券商中有10家营收超百亿元,多达30家完成了营收及净利润的双增。头部券商中,我国银河近年来运营成绩排名稳步提高;区域性券商中,天风证券成最大黑马&#…

iOS逆向进阶:iOS进程间通信方案深入探究与local socket介绍

在移动应用开发中,进程间通信(Inter-Process Communication,IPC)是一项至关重要的技术,用于不同应用之间的协作和数据共享。在iOS生态系统中,进程和线程是基本的概念,而进程间通信方案则为应用的…

行政固定资产应该怎么管理

行政需要管理的固定资产主要包括办公设备、交通工具、通讯设备、家具等。具体来说,行政需要管理的固定资产包括但不限于:电脑、打印机、传真机、复印机、投影仪、电话、传真机、传真纸、电话线、路由器、交换机、服务器、UPS电源、办公桌椅、沙发等。 行…

Java小项目【图书馆系统】

一、设计图书馆系统 Java是一个面向对象的语言,在编写代码的之前,我们要先确定有哪些对象 图书馆,首先有很多书,还有书架来放置这些书。然后是对书进行操作的人,比如普通用户和管理员。最后是对关于书的各种操作&#…

如何检测勒索软件攻击

什么是勒索软件 勒索软件又称勒索病毒,是一种特殊的恶意软件,又被归类为“阻断访问式攻击”(denial-of-access attack),与其他病毒最大的不同在于攻击方法以及中毒方式。 攻击方法:攻击它采用技术手段限制…

软件系统第三方检测费标准

收费标准 软件系统第三方检测收费标准: 行业内对于第三方软件测试报告并没有一个明确的收费标准,不同地域之间的收费不同,各个检测单位的报价也略有差异。第三方检测报告的收费标准需要根据具体的测试需求而定,一般是按照项目大…

“算力+运力”扇动双翼,制造算力时代的蝴蝶效应

8月18日-20日,第二届中国算力大会在宁夏银川成功举办。 今年以来,随着大模型、AIGC等新技术的火爆,站在舞台中央的算力承载了无尽的期待,发展数字经济需要以算力基础设施为前提,社会各界已经形成了共识。 与此同时&…

一文速学-让神经网络不再神秘,一天速学神经网络基础(五)-最优化

前言 思索了很久到底要不要出深度学习内容,毕竟在数学建模专栏里边的机器学习内容还有一大半算法没有更新,很多坑都没有填满,而且现在深度学习的文章和学习课程都十分的多,我考虑了很久决定还是得出神经网络系列文章,…

Kafka系列三基础概念

文章首发于个人博客,欢迎访问关注:https://www.lin2j.tech Kafka 是一款分布式消息发布和订阅系统,其高性能、高吞吐量的特点决定了其适用于大数据传输场景。 基础概念 Broker Broker 其实就是一个运行 Kafka 服务的服务器。Kafka 集群包…

chatGPT训练过程

强化学习基础 强化学习是指智能体在不确定环境中最大化其获得的奖励从而达到自主决策的目的。其执行过程为:智能体依据策略决策从而执行动作,然后感知环境获取环境的状态,进而得到奖励(以便下次再到相同状态时能采取更优的动作),…

(java)进程和线程的联系和区别 。Java如何进行多线程编程?Thread 类及常见方法。

目录 进程 1.进程具有独立性 ———— 虚拟地址空间 线程 为什么要引入多个线程? 多线程注意点 ⁜⁜总结:线程和进程的区别和联系⁜⁜ (经典面试题) Java如何进行多线程编程? 创建线程 ——方法1 继承 Thre…

webrtc 的Bundle group 和RTCP-MUX

1,最近调试程序的时候发现抱一个错误 max-bundle configured but session description has no BUNDLE group 最后发现是一个参数设置错误 config.bundle_policy webrtc::PeerConnectionInterface::BundlePolicy::kBundlePolicyMaxBundle; 2,rtcp-mu…

SpringBoot项目,执行install命名时,控制台显示:Unable to find main class

构建springboot多模块项目,启动时可以正常启动,执行了父工程的maven的clean也没问题,执行install的时候就报错了:Unable to find main class。显而易见 这个错是找不到主类。 记录下解决过程: 首先看自己项目的父工程…

膦酸基官能团高盐环境下去除钙镁离子树脂

项目名称 某新能源公司除钙镁项目 工艺选择 串联运行 工艺原理 膦酸基官能团高盐环境下去除钙镁离子 项目背景 锂及其盐类是国民经济和国防建设中具有重要意义的战略物资,也是与人们生活息息相关的能源材料。而碳酸Li作为锂盐的基础盐,是制取锂化…

Matlab 基本教程

1 清空环境变量及命令 clear all % 清除Workspace 中的所有变量 clc % 清除Command Windows 中的所有命令 2 变量命令规则 (1)变量名长度不超过63位 (2)变量名以字母开头, 可以由字母、数字和下划线…