手撕Pytorch源码#3.Dataset类 part3

news2025/1/17 23:20:54

写在前面

  1. 手撕Pytorch源码系列目的:

  • 通过手撕源码复习+了解高级python语法

  • 熟悉对pytorch框架的掌握

  • 在每一类完成源码分析后,会与常规深度学习训练脚本进行对照

  • 本系列预计先手撕python层源码,再进一步手撕c源码

  1. 版本信息

python:3.6.13

pytorch:1.10.2

  1. 本博文涉及python语法点

  • 泛型类Union和Optional

  • __getattr__方法

  • Iterable,Iterator和forloop

  • functools.partial

  • MRO与C3算法

目录

[TOC]

零、流程图

一、IterableDataset

1.0 源代码
class IterableDataset(Dataset[T_co], metaclass=_DataPipeMeta):
    functions: Dict[str, Callable] = {}
    # Optional也是泛型编程的常用函数,表示
    reduce_ex_hook : Optional[Callable] = None
	# __iter__方法说明此类是Iterable可迭代对象
    # 而__iter__函数返回的是Iterattor迭代器对象
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
	# __add__函数在Dataset类中同样出现了,用于数据集的拼接
    # Dataset中的__add__方法是通过ConcatDataset来实现的 
    def __add__(self, other: Dataset[T_co]):
        # ChainDataset的源码分析见下一篇博文
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

    def __getattr__(self, attribute_name):
        # 注意IterableDataset.functions与self.functions是不同的
        # 前者是调用类属性,后者是调用对象属性
        # 根据前面functions的定义,其为类属性
        if attribute_name in IterableDataset.functions:
            function = functools.partial(IterableDataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    def __reduce_ex__(self, *args, **kwargs):
        if IterableDataset.reduce_ex_hook is not None:
            try:
                return IterableDataset.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if IterableDataset.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        IterableDataset.reduce_ex_hook = hook_fn
1.1 reduce_ex_hook : Optional[Callable] = None
  1. Optional[Callable]:Optional也是泛型编程的重要函数,与Union,Generic等类常出现在程序中

  1. Union[int,str]表示可能的类型范围是int以及str,因而Union类表示类比的或操作

  1. Optional[Callable]相当于输入类Callable与None类的结合,即Union[Callable,None]

4.泛型编程概念见博文**[【手撕Pytorch源码#1.Dataset类 part1】]((12条消息) 手撕Pytorch源码#1.Dataset类 part1_望 尘�的博客-CSDN博客)**

1.2 def __iter__(self) -> Iterator[T_co]
  1. __iter__函数标志该类是可迭代对象Iterable,关于可迭代对象Iterable和迭代器Iterator以及最常用的for循环的原理见【2.1节 Iterable与Iterator和for loop】

  1. 此处的__iter__和Dataset类的__iter__方法一样都需要自己实现,否则就会报错NotImplementedError

1.3 def __getattr__(self, attribute_name)
  1. __getattr__方法用于当对象效用的属性或方法无法找到时,解释器便会调用__getattr__函数

1.4 function = functools.partial(IterableDataset.functions[attribute_name], self)
  1. functions.partial可以给固定函数传入相应的值,精讲见【2.2节 functools.partial】

1.5 __reduce_ex__(self, *args, **kwargs)与@classmethod
  1. 由于本期内容较为硬核,因而关于__reduce__,__reduce_ex__,*args,**kwargs和@classmethod放到下一期进行精讲

二、相应的Python语法补充

2.1 Iterable与Iterator和for loop
  1. 前文源代码中出现了__iter__方法,声明该类是Iterable可迭代对象,而__iter__方法返回的则是一个Iterator迭代器对象,因而趁此机会研究一下Iterable和Iterator的区别

  1. 为了比较两者的区别,我用python实现了链表的数据结构,代码如下:

# 用python实现链表的数据结构

class NodeIterator():
    def __init__(self,node:'Node') -> None:
        # Iterator必须要储存当下的状态,也就是现在调用到哪一位
        # 有点像C语言的指针
        # 下面的self.current_node就是储存当前状态的
        self.current_node = node
    def __next__(self):
        if self.current_node is None:
            raise StopIteration
        node,self.current_node = self.current_node,self.current_node.next
        return node
    # python官方要求Iterator对象也必须定义__iter__方法,原因见下①
    def __iter__(self):
        return self

class Node():
    def __init__(self,data) -> None:
        # self.data是链表结点存的数据
        # 可迭代对象Iterable更像是一个数据的容器,而不太在乎当前数据迭代的对象
        # 下面self.data其实就是承装了数据,起到container的作用
        self.data = data
        # self.next是链表结点的next指针
        self.next = None
    # 我要让链表是一个可迭代对象,必然需要__iter__
    def __iter__(self):
        return NodeIterator(self)

node1 = Node("Node1")
node2 = Node("Node2")
node3 = Node("Node3")
node1.next = node2
node2.next = node3
# 如果有人希望直接从node1链表中的第二个元素开始遍历,会写出以下代码
it = iter(node1)
first = next(it)
print(first.data)
# 如果这里不在Iterator中定义__iter__函数,那么下面的代码就会报错
for node in it:
    print(node.data)
  • 在上述代码中Node是可迭代对象Iterable,NodeIterator是迭代器对象Iterator

  • 对比两个类,迭代器对象于可迭代对象的最大区别为:

  • Iterable对象更像是一个数据容器container,能够承装数据,如常见的数据结构list,tuple,dict都是可迭代对象

  • 而Iterator则不需要保存数据,而需要保存状态,即当前迭代到哪一个数据为,上述代码中,class NodeIterator里的self.current_node就是用于保存当前迭代到的结点

  • 同时,从类程序上看,定义了__iter__方法就可以成为Iterable对象,定义了__next__方法就可以成为Iterator对象

  • Iterable类与Iterator类定义的其他注意事项

  • Iterable类中__iter__返回的是一个迭代器Iterator对象

  • Iterator中也必须定义__iter__函数,保证其也是一个Iterable对象,而其__iter__函数一般直接return self即可,如果在Iterator中不定义__iter__函数,则有可能出现错误(见上述代码的注释)

  • Iterator类的__next__函数,需要判断迭代是否结束,如果结束,需要raise StopIteration以标志迭代结束

  1. for loop的运作过程

  • 首先程序会判断for .. in x中的x是否为可迭代对象,如果不是,直接报错

  • 在运行for循环之前,程序会首先将可迭代对象Iterable通过iter(Iterable)调用其__iter__方法生成迭代器Iterator,在通过迭代器逐步取值

  • 可以查看以下for循环的字节码,便可以直观了解上述过程

# 查看以下for循环的字节码Bytecode
import typing
import dis
def for_func(lst:typing.List[int])->None:
    for num in lst:
        print(num)
dis.dis(for_func)

# 上述代码的字节码如下:
#  49           0 LOAD_FAST                0 (lst)
#               2 GET_ITER
#         >>    4 FOR_ITER                12 (to 18)
#               6 STORE_FAST               1 (num)

#  50           8 LOAD_GLOBAL              0 (print)
#              10 LOAD_FAST                1 (num)
#              12 CALL_FUNCTION            1
#              14 POP_TOP
#              16 JUMP_ABSOLUTE            4
#         >>   18 LOAD_CONST               0 (None)
#              20 RETURN_VALUE
  • 上述字节码中49-2 GET_ITER就是从Iterable中取出对应的Iterator

  • 关于字节码ByteCode的相关理论,等有空再开一期专门研究Cpython源码

2.2 functools.partial
  1. functools.partial()用于给函数传递参数,并且返回传参后的函数:该函数第一个参数为函数名,后面的参数为需要对函数传入的参数值

  1. 下面是functools.partial的实用场景,直接上代码👇

import functools
import typing
def display_age(age:int)->None:
    print("your age:{}".format(age))

def display_height(height:float)->None:
    print("your height:{} cm".format(height))

def callback_fn(callback:typing.Callable)->None:
    print("That's where functools.partial works!")
    callback()

# 由于callback_fn的输入是一个函数,且该函数没有参数列表
# 因此需要提前对其进行传参
d_age = functools.partial(display_age,30)
d_height = functools.partial(display_height,235.62)

callback_fn(d_age)
callback_fn(d_height)

# 输出结果为:
# That's where functools.partial works!
# your age:30
# That's where functools.partial works!
# your height:235.62 cm
  1. 在上述代码中,由于callback_fn的输入参数为一个函数,且其没有默认参数,因而需要提前对该函数进行传参。可能你有这样的问题:那么下面这种写法不久好了?👇

def callback_fn(callback:typing.Callable,age:int)->None:
    print("That's where functools.partial works!")
    callback(int)
  • 的确,如果仅对于display_age函数,这样写确实可以,但如果callback_fn的参数为多个不同输入值的函数,那么这种写法就必然会造成极大的麻烦,functools.partial就有比较大的优越性

2.3 MRO
  1. mro:Method Resolution Order(方法解析顺序),即一个子类,其父类函数的优先级顺序链

2.4 C3算法
  1. 本博文最硬核的部分来了,先亮出1996年原论文,干王可以手撕原论文[A monotonic Superclass Linearization for Dylan](A monotonic superclass linearization for Dylan (acm.org))

  1. 本博文仅就C3算法的三个假设以及计算方法进行阐述,由于概念较为抽象,因而尽量采用图与代码对应的形式进行呈现

2.4.1 C3算法的三个假设
  1. preservation of local precedence order局部优先顺序

  • 先上代码:

class A:
    def display(self):
        print("A")
class B(A):
    def display(self):
        print("B")
class C(A):
    pass
class D(C,B):
    pass

d = D()
d.display()

# 输出结果为:
# B
  • 再看继承结构图和MRO链

  • 局部优先顺序指的是D类同时继承C类和B类,程序代码为class D(C,B),因此,依照此顺序,在D以及其所有子类的MRO链中,C类一定排在B类的前面

  1. fitting a monotonically criterion单调性准则

  • 单调性准则的描述:子类的MRO链选择必须来自其直接父类,而不能是其他的选择

  • 单调性准则引用原论文中的例子

  • 首先对<pedalo>类进行分析,可以看见,<pedalo>类是<pedal-wheel-boat>和<small-catamaran>的直接子类,因此<pedalo>类的MRO直接选择必须来自<pedal-wheel-boat>和<small-catamaran>两类之一

  • 而观察<pedal-wheel-boat>类和<small-catamaran>类的MRO链可以发现:<pedal-wheel-boat>类中<day-boat>类排序高于<wheel-boat>类,且<small-catamaran>类中MRO链没有<wheel-boat>类

  • 但是观察<pedalo>类的MRO链可以发现:<wheel-boat>类排序高于<day-boat>类。因此,如果一个类函数仅存在于<day-boat>类和<wheel-boat>类中,那么<pedal-wheel-boat>类和<small-catamaran>类将会执行<day-boat>类的函数,而<pedalo>类将会执行<wheel-boat>类的函数,与单调性原则不符

  1. a consistent extended precedence graph拓展优先图

  • 用于解决一个类的子类和其父类的优先级顺序

  • 抽象的表达:取决于两个的最小公共子类上,两类或其子类的优先级顺序

  • 同样用论文中的例子进行演示

  • 根据local precedence原则,对于<editable-scrollable-pane>类而言,<scrollabel-pane>类排在<editable-pane>类之前。对于<scrollabel-pane>类而言,<pane>类排在<scrolling-mixin>类之前。对于<editable-pane>类而言,<pane>类排在<editable-mixin>类之前。

  • 但是我们希望能够对<scrolling-mixin>类和<editable-mixin>类进行排序

  • 首先我们找到<scrolling-mixin>类和<editable-mixin>类的最小公共子类<editable-scrollable-pane>类

  • 在从<editable-scrollable-pane>类开始依次比较<scrolling-mixin>类和<editable-mixin>类以及其子类的优先顺序

  • 在上图中,我们比较<scrolling-mixin>类的子类<scrollabel-pane>类以及<editable-mixin>类的子类<editable-pane>类的优先级顺序。

  • 明显,由local precedence原则,<scrollabel-pane>类排在<editable-pane>类之前,因此<scrolling-mixin>类排在<editable-mixin>类之前

  • 因此,上述继承图的MRO链如下:

2.4.2 C3算法计算方法
  • 看一个例子:

  • 最后计算f(A)

2.4.3 merge函数计算方法
  1. 以上例中最后一步的merge函数计算为例

  • 首先观察merge函数中的参数,从第一个参数的一个元素B开始取,观察所有参数的后几位(第二位及以后)是否有B元素出现,若有,则不能加入结果列表中,否则就可以加入结果列表中,运算如下:

  • 接着对第一个参数的第二个元素重复上述操作,运算如下:

  • 接着对第一个参数第三个元素obj进行分析,发现其余参数的第二位及以后的元素中仍有obj元素出现,因而obj不可以加入结果列表中

  • 后面的步骤依次运算如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/175989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Couplet | 用Python写一副对联送给诸位科研汪!~

1写在前面 小伙伴们大家新年好啊&#xff01;&#xff01;&#xff01;&#x1f970; 又是一年新春到&#xff0c;玉兔祝福要记牢&#xff1a;蹦蹦跳跳身体棒&#xff0c;平平淡淡精神爽&#xff0c;红红火火财运旺&#xff0c;和和气气朋友广&#xff0c;简简单单幸福长。&…

Windows SDK编程 初学笔记

#include "windows.h"int WINAPI WinMain(HINSTANCE hinstance, HINSTANCE hPreInstance, PSTR szCmdLine, int iCmdShow) {MessageBox(NULL, TEXT("来见见世面"), TEXT("Say Hi"), MB_OK);return 0; } MessageBox第一个参数为句柄&#xff0c;第…

Socket通信

什么是Socket?

Java基础——运算符与表达式

目录 Eclipse下载 安装 使用 运算符 键盘录入 Eclipse下载 安装 使用 Eclipse的概述(磨刀不误砍柴工)——是一个IDE(集成开发环境)Eclipse的特点描述&#xff08;1&#xff09;免费 &#xff08;2&#xff09;纯Java语言编写 &#xff08;3&#xff09;免安装 &#xff08…

【手把手教你学51单片机】

注&#xff1a;本文章转载自《手把手教你学习51单片机》&#xff01;因转载需要原文链接&#xff0c;故无法选择转载&#xff01; 如若侵权&#xff0c;请联系我进行删除&#xff01;上传至网络博客目的为了记录自己学习的过程的同时&#xff0c;同时能够帮助其他一同学习的小伙…

AJAX Axios 总结

AJAX & Axios1. AJAX1.1 作用①与服务器进行数据交换②异步交互异步和同步1.2 基本使用1.3 案例SelectUserServlet&#xff1a;register.html&#xff1a;register.html中的<script2. Axios异步框架2.1 基本使用2.2 案例axiosServlet&#xff1a;axios-demo.html&#x…

Elasticsearch7.8.0版本高级查询—— 聚合查询文档

目录一、初始化文档数据二、聚合查询文档2.1、概述2.2、对某个字段取最大值 max 示例2.3、对某个字段取最小值 min 示例2.4、对某个字段求和sum 示例2.5、对某个字段取平均值 avg 示例2.6、对某个字段的值进行去重之后再取总数 示例三、State 聚合查询文档3.1、概述3.2、示例一…

目标检测论文解读复现【NO.24】改进 YOLOv5s 的轨道障碍物检测模型轻量化研究

前言此前出了目标改进算法专栏&#xff0c;但是对于应用于什么场景&#xff0c;需要什么改进方法对应与自己的应用场景有效果&#xff0c;并且多少改进点能发什么水平的文章&#xff0c;为解决大家的困惑&#xff0c;此系列文章旨在给大家解读最新目标检测算法论文&#xff0c;…

Cadence PCB仿真使用Allegro PCB SI生成反射仿真报告及报告导读图文教程

🏡《Cadence 开发合集目录》   🏡《Cadence PCB 仿真宝典目录》 目录 1,概述2,生成报告3,报告导读4,总结1,概述 本文简单介绍使用Allegro PCB SI生成网络的反射性能评估的报告的方法,及反射报告要点导读。 2,生成报告 第1步,选择需要生成报告的网络,然后单击右…

(侯捷C++)1.2面向对象高级编程(上)

1.整体结构 2.三大函数&#xff1a;拷贝构造&#xff0c;拷贝赋值&#xff0c;析构 拷贝构造&#xff1a;第一次出现对象&#xff0c;使用拷贝构造进行创建&#xff0c;例如&#xff1a;String s3(s1)。拷贝赋值&#xff1a;对象已经构造&#xff0c;重新赋值&#xff0c;例如…

人工智能辅助药物发现(4)药物重定位

目录药物重定位概述药物重定位数据库表示学习基于序列的表示学习基于图的表示学习药物重定位深度学习以靶点为中心以疾病为中心药物重定位的应用药物重定位概述 新药物的研发投资巨大&#xff0c;周期漫长。从获批准的临床药物中有效识别新的适应药物在药物发现中起到重要作用…

cc123 靶场测试笔记

1.cc123 靶场介绍本靶场存在四个 flag 把下载到的虚拟机环境导入到虚拟机&#xff0c;本靶场需要把网络环境配置好。1.1.网络示意图2. 信息收集2.1.主机发现sudo netdiscover -i eth0 -r 192.168.0.0/242.2.masscan 端口扫描sudo masscan -p 1-65535 192.168.1.102 --rate10002…

Elasticsearch7.8.0版本高级查询—— 高亮查询文档

目录一、初始化文档数据二、高亮查询文档2.1、概述2.2、示例一、初始化文档数据 在 Postman 中&#xff0c;向 ES 服务器发 POST 请求 &#xff1a;http://localhost:9200/user/_doc/1&#xff0c;请求体内容为&#xff1a; { "name":"zhangsan", "ag…

<Python的文件>——《Python》

目录 1.文件 1.1 文件是什么 1.2 文件路径 1.3 文件操作 1.3.1 打开文件 1.3.2 关闭文件 1.3.3 写文件 1.3.4 读文件 1.3.5 关于中文的处理 1.4 使用上下文管理器 1.文件 1.1 文件是什么 变量是把数据保存到内存中. 如果程序重启/主机重启, 内存中的数据就会丢失.…

23种设计模式(十八)——组合模式【数据结构】

文章目录 意图什么时候使用组合真实世界类比组合模式的实现组合模式的优缺点亦称: 对象树、Object Tree、Composite 意图 有时又叫作整体-部分(Part-Whole)模式,是一种将对象组合成树状的层次结构的模式,用来表示“整体-部分”的关系,使用户对单个对象和组合对象具有一致…

【并发编程】Executor线程池

一、线程 1.线程 线程是调度CPU资源的最小单位。java线程与OS线程保持1:1映射关系&#xff0c;也就是说&#xff0c;一个Java线程也会在操作系统里有一个对应线程。 2.线程的生命周期 NEW,新建 RUNNABLE,运行 BLOCKED,阻塞 WAITING,等待 TIMED_WAITING,超时等待 TERMINATED…

超级完整的 Git 下载、安装与配置

Git的下载、安装与配置 一、git下载安装 1、访问git官方下载网址&#xff0c;点击这里&#xff0c;然后根据自己的电脑系统&#xff0c;下载对应的安装包&#xff1a; 2、在淘宝镜像网站 下载对应的安装包&#xff1a; 注&#xff1a; 如果由于官网下载速度过于缓慢&#xff…

String 有趣简单的编程题

String 有趣简单的编程题 每博一文案 师父说: 世上没有真正的感同身受&#xff0c;也没有谁能完全做到将心比心&#xff0c;我们一路走来。 慢慢的学会了收敛情绪&#xff0c;越成熟越沉默&#xff0c;有些人&#xff0c;背负沉重的压力&#xff0c;却从来不敢说累&#xff0c…

[python刷题模板] 树的直径/换根DP

[python刷题模板] 树的直径/换根DP 一、 算法&数据结构1. 描述2. 复杂度分析3. 常见应用4. 常用优化二、 模板代码1. 单纯询问树的直径值2. 求出树的直径两端搞事情3. 换根DP求树的直径(大炮打蚊子&#xff0c;别这么做&#xff0c;只是用来帮助理解换根DP)4. 换根dp求特定…

UDS诊断系列介绍14-2F服务

本文框架1. 系列介绍1.1 2F服务概述2. 2F服务请求与应答2.1 2F服务请求2.2 2F服务正响应2.3 2F服务否定响应3. 2F诊断使用示例4. Autosar系列文章快速链接1. 系列介绍 UDS&#xff08;Unified Diagnostic Services&#xff09;协议&#xff0c;即统一的诊断服务&#xff0c;是…