手撕Pytorch源码#2.Dataset类 part2

news2024/12/25 1:54:57

写在前面

  1. 手撕Pytorch源码系列目的:

  • 通过手撕源码复习+了解高级python语法

  • 熟悉对pytorch框架的掌握

  • 在每一类完成源码分析后,会与常规深度学习训练脚本进行对照

  • 本系列预计先手撕python层源码,再进一步手撕c源码

  1. 版本信息

python:3.6.13

pytorch:1.10.2

  1. 本博文涉及python语法点

  • @staticmethod修饰器

  • super类的全新理解【大概率有你闻所未闻的华点!】

  • bisect二分法搜索方法

  • @property修饰器

目录

[TOC]

零、流程图

一、ConcatDataset类

  1. 在上篇博文中,Dataset类的__add__方法中使用了ConcatDataset类,因而本篇对其进行研究,并学习相关的python语法点

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
  • Dataset[T_co],Type hint用法见上一篇博文

1.0 源代码
  • 注:在第一部分未进行精讲的代码已在下方源代码处做了详细的注释!

class ConcatDataset(Dataset[T_co]):
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # 将Iterable[Dataset]转化成为List(Dataset)
        self.datasets = list(datasets)
        # 对传入的datasets序列的长度进行合法性判断
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        # 仅接收Dataset类型,不接收IterableDataset类型
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        # cumulative_size的最后一个元素就是整个数据集序列的总长度
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # 考虑的是用负数进行索引的情况
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            # 对负数的索引值进行正向化,只需要用总长度+负数索引值即可
            # 例如:对最后一个元素的负数索引为-1,总长度为n,则n-1恰好为正向索引值
            # len(self)相当于调用__len__方法,获取总长度
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # 如果判定就是第一个数据集,那么对数据集的索引就是总的索引
        if dataset_idx == 0:
            sample_idx = idx
        # 如果判定不是第一个数据集,那么对数据集的索引就不是总的索引
        # 例如,计算出来的cumulative_size为[100,300,600,1000],而总索引为569
        # 那么可以判定出来索引位于第三个数据集上(数据集大小为300)
        # 那么在第三个数据集上的索引应该是569-300=269
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
1.1 @staticmethod以及cumsum()函数
  • @符号表示修饰器,而staticmethod则声明了该函数为类的静态方法,具体解释见【2.1节@staticmethod】

  • cumsum(sequence)用于计算每个数据集累计长度的序列,为了方便__getitem__方法中,通过索引的序列号index定位序列号位于哪个数据集

  • 假定传入的数据集序列为datasets = [ds1,ds2,ds3,ds4]共有四个数据集,其长度分别为[100,200,300,400],则cumcum(sequence)函数则会生成列表cumulative = [100,300,600,1000]。如果__getitem__方法传入的序列号为596,则可以用cumulative列表判断该序号的数据属于第三个数据集,进而可以从ds3中取得相应的数据

1.2 super(ConcatDataset, self).__init__()
  • 上述语句的作用为调用ConcatDataset类的父类Dataset的初始化函数

  • super类的精讲见【2.2节super类】,关于其中涉及的mro链的精讲见下一篇博文

1.3 dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  • bisect_right函数的精讲见【2.3节bisect二分法搜索】

  • 源代码利用bisect_right通过二分法找到目标索引值index在前文所述cumsum(sequence)函数生成的列表中的位置,并且通过该位置即可判定该索引位于数据集序列的第几个数据集中,从而计算出dataset_idx

1.4 @properpy以及cummulative_sizes(self)
  1. @properpy用于是函数可以像属性一样被直接调用,精讲见【2.4节@property】

  1. cummulative_sizes(self)函数其实为了版本兼容,通过warning.warn提示用户此方法已经改名,并将正确的值通过函数传递给该方法

  1. warnings.warn(message, category=None, stacklevel=1, source=None)函数用于抛出警告,而上述代码中的DeprecationWarning是代码被弃用的警告

二、相应的Python语法补充

2.1 @staticmethod
  • @符号表示修饰器,而staticmethod则声明了该函数为类的静态方法,静态方法直接属于类,不需要传入self参数,正如外部函数一样进行定义,并且可以直接通过类或对象进行调用

class Static():
    def __init__(self,x:int)->None:
        self.x = x
    @staticmethod
    def static()->None:
        print("Trying static")

Static.static()
s = Static(1)
s.static()

# 输出
# Trying static
# Trying static
2.2 super类
  1. 冷知识:super()其实并不是一个函数,更不是一个关键字,而是一个类,因此在程序中使用super()是创建了一个对象,而super的原型则是super(type,type or object)

  1. super类传入的两个参数分别代表什么:

  • 第一个type为一个类名,如上代码中的ConcatDataset,而第二个object则往往是一个M object,决定一个mro chain【mro链精讲见下一篇博文】

  • 而第一个type决定了在mro chain中的取值位置,即从第一个type类下一个类开始,在第二个object决定的mro链上寻找最近的函数进行调用

  • 当然,光说很抽象,直接上代码

from objprint import op
class Animal():
    def __init__(self,age):
        self.age = age
    
class Person(Animal):
    def __init__(self, age,name):
        super(Person,self).__init__(age)
        self.name = name

class Male(Person):
    def __init__(self,age,name):
        super(Person,self).__init__(age)
        self.gender = 'male'

class Female(Person):
    def __init__(self,age,name):
        super(Female,self).__init__(age,name)
        self.gender = 'female'

m = Male(50,"Tim")
fm = Female(40,"Lily")

op(m)
op(fm)

# 输出结果:
# <Male 0x166c47a5d90
#   .age = 50,
#   .gender = 'male'
# >
# <Female 0x166c47a5c10
#   .age = 40,
#   .gender = 'female',
#   .name = 'Lily'
# >
  • 根据上述代码结果可以发现:Male类与Female中super()的第二个参数均为各自类别的self,因此Male类与Female类对应的mro链为Male:Male->Person->Animal,Female:Female->Person->Animal

  • Male:而在Male类中super()传入的第一个参数为Person因此从mro链中Person的下一个类,即Animal类开始查找__init__初始化函数,因此最终Male类的初始化函数相当于仅调用了Animal.__init__函数,因而无法对name属性进行初始化定义

  • Female:而Female类中super()传入的第一个参数为Female因此从mro链中Female的下一个类,即Person类开始查找__init__初始化函数,因此首先调用Person.__init__对name属性进行初始化

  • Female:接着,对于Person类,其super()的第二个参数为Person类比的self,因此其mro链为Person:Person->Animal,故从Animal类开始寻找__init__函数进行初始化。最终Female类的初始化函数相当于同时调用了Person.__init__函数以及Animal.__init__函数,因而对name和age属性都进行了初始化定义

  1. super类可能存在的黑魔法

  • 首先上代码看例子

class A():
    def display(self):
        print("A")

class B(A):
    def display(self):
        super(B,self).display()

class C(A):
    def display(self):
        print("C")

class D(B,C):
    def display(self):
        super(B,self).display()

b = B()
d = D()
b.display()
d.display()

# 输出结果为:
# A
# C
  • 或者对上述代码中的D类换一种写法

    class D(B,C):
        def display(self):
            B.display(self)

    d = D()
    d.display()
# 输出结果为:
# C
  • 那么为什么调用B.display函数最后会输出C,而C则和B没有任何的继承与被继承关系呢?

  • 首先,由于D继承了B和C两个类,因此其mro链为D:D->B->C->A,而super()类输入的第一个参数为B,即从B的下一个类开始寻找display()函数,进而调用C的display()函数最终输出了C

  • 因此,虽然B与C类没有任何的继承与被继承关系,但两类通过D类的mro链被联系在一起

2.3 bisect二分法搜索
  1. python的bisect库有三种二分法搜索的API分别是:bisect.bisect(Sequence,x),bisect.bisect_left(Sequence,x),bisect.bisect_right(Sequence,x)

  • 其中,在二分法搜索上bisect.bisect(Sequence,x)与bisect.bisect_right(Sequence,x)是完全等效的

  • bisect.bisect(Sequence,x)与bisect.bisect_right(Sequence,x)是找到传入值x的最小下标值

  • 而bisect.bisect_left(Sequence,x)则是找到传入值x的最小下标值

  1. 话不多说,直接上代码:

import bisect
l = [100,255,512,1036]
print(bisect.bisect(l,500))
print(bisect.bisect_right(l,500))
print(bisect.bisect_left(l,500))

print(bisect.bisect(l,512))
print(bisect.bisect_right(l,512))
print(bisect.bisect_left(l,512))

# 输出值为:
# 2
# 2
# 2
# 3
# 3
# 2
2.4 @property
  1. @property目的是可以让函数像属性一样被调用

  1. 而将属性变化成函数则是可以避免类调用时,对象的属性被非法修改

  1. 如一个长方形类,其属性为长,宽与面积,而如果将面积作为属性进行定义,那么在类外,可以直接利用self.area对面积值进行修改,从而使得长宽与面积值的不匹配,而用@property修饰器则可以解决该问题

  1. 代码示例如下:

class Property():
    def __init__(self,height,width) -> None:
        self.height = height
        self.width = width

    @property
    def area(self):
        return self.height*self.width


p = Property(2,4)
print(p.area)

# 输出结果为:
# 8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170943.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

十四.文件操作

目录 一.为什么使用文件 二.什么是文件 1.程序文件和数据文件 2.文件名 三.文件的打开和关闭 1.文件指针 2.fopen函数和fclose函数 四.文件的顺序读写 1.顺序读写函数一览表 2.主要输入输出函数介绍 &#xff08;1&#xff09;字符输出函数futc &#xff08;2&…

Python采集*瓣电影影评并实现可视化分析

前言 嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 又到了学Python时刻~ 环境使用: Python 3.8 解释器 Pycharm 编辑器 模块使用 import parsel >>> pip install parsel import requests >>> pip install requests import csv 安装python第三方…

LeetCode 1825 求出MK平均值【Set 队列】 HERODING的LeetCode之路

解题思路&#xff1a; 好久没更新力扣困难题的题解了&#xff0c;今天这道困难题有点意思&#xff0c;读罢题目一目了然&#xff0c;解题思路清晰明了&#xff0c;就是解题过程细节满满。这是一个数据流场景的问题&#xff0c;保留最后m个元素&#xff0c;但是要去除k个最大&am…

设计模式—工厂方法模式

工厂方法模式 文章目录工厂方法模式工厂方法模式是什么理解工厂方法模式代码实例运行截图工厂方法的优点工厂方法的不足工厂方法模式是什么 工厂方法模式属于创建型模式&#xff0c;也叫抽象构造模式&#xff0c; 工厂方法模式将工厂抽象化&#xff0c;并定义一个创建对象的接…

高级语言(C语言)、汇编语言、机器语言区别?编译器如何将高级语言编译成机器语言?

⾼级语⾔&#xff1a; 是相对于汇编语⾔⽽⾔的&#xff0c;是⾼度封装了的编程语⾔&#xff0c;与低级语⾔相对。它是以⼈类的⽇常语⾔为基础的⼀种编程语⾔&#xff0c;使⽤⼀般⼈易于接受的⽂字来表⽰&#xff08;例如汉字、不规则英⽂或其他外语&#xff09;&#xff0c;从…

(二十四)List系列集合

目录 前言: 一、List集合的特有方法 二、List集合的遍历方式有几种&#xff1f; 三、Arraylist集合底层原理 四、LinkedList的特点 前言: List集合包括JavaList接口以及List接口的所有实现类。List集合中的元素允许重复&#xff0c;各元素的顺序放是对象插入的顺序&#xff…

C生万物 | C语言文件操作指南汇总【内附文件外排序源码】

&#x1f451;作者主页&#xff1a;Fire_Cloud_1 &#x1f3e0;学习社区&#xff1a;烈火神盾 &#x1f517;专栏链接&#xff1a;万物之源——C 文章目录一、为什么使用文件&#xff1f;二、什么是文件&#xff1f;1、程序文件2、数据文件3、文件名三、文件的打开和关闭1、文件…

自动化测试【软件测试】

自动化测试 什么是自动化 有效减少人力的消耗&#xff0c;同时提高生活的质量 通过自动化测试有效减少人力的投入&#xff0c;同时提高了测试的质量和效率 由于回归测试&#xff0c;版本越来越多&#xff0c;版本回归的压力越来越大&#xff0c;仅仅通过人工测试来回归所有版本…

2.3、进程控制

整体框架 1、什么是进程控制&#xff1f; 进程控制的主要功能是对系统中的所有进程实施有效的管理&#xff0c; 它具有创建新进程、撤销已有进程、实现进程状态转换等功能。 简单来说&#xff1a;进程控制就是要实现进程状态转换 2、如何实现进程控制&#xff1f; 2.1、进程…

ATAC-seq分析:TSS 信号(7)

ATACseq ATACseq - 使用转座酶并提供一种同时从单个样本的转录因子结合位点和核小体位置提取信号的方法。 1. 数据类型 上面这意味着我们的数据中可能包含多种信号类型。 我们将从无核小体区域和转录因子&#xff08;我们的较短片段&#xff09;周围获得信号。我们的一部分信号…

2-Spring核心与设计思想

目录 1.Spring是什么&#xff1f; 2.容器是什么&#xff1f; 3.IoC是什么&#xff1f; 3.1.传统程序开发 3.2.控制反转式程序开发 3.3.对比总结规律 4.理解Spring IoC 4.1.将对象(Bean)存入到容器(Spring)&#xff1b; 4.2.从容器中取出对象。 5.DI概念说明 1.Spring…

计算机编程背景

&#x1f496; 欢迎来阅读子豪的博客&#xff08;JavaEE篇 &#x1f934;&#xff09; &#x1f449; 有宝贵的意见或建议可以在留言区留言 &#x1f4bb; 欢迎 素质三连 点赞 关注 收藏 &#x1f9d1;‍&#x1f680;码云仓库&#xff1a;补集王子的代码仓库 不要偷走我小火…

classpath类路径是什么

Spring Boot 一、简介 classpath类路径在 Spring Boot 中既指程序在打包前的/java/目录加上/resource目录&#xff0c;也指程序在打包后生成的/classes/目录。两者实际上指的是同一个目录&#xff0c;里面包含的文件内容一模一样。 二、获取classpath路径 以下两种方式均可…

Icarus Verilog

Icarus Verilog 是一个Verilog仿真工具&#xff0c;以编译器的形式工作&#xff0c;将以verilog编写的源代码编译为某种目标格式。如果要进行仿真的话&#xff0c;可以生成一个vvp的中间格式&#xff0c;由其所附带的vvp命令执行。 https://github.com/steveicarus/iverilog …

面试官:请设计一个能支撑百万连接的系统架构!

目录 1、到底什么是连接&#xff1f;2、为什么每次发送请求都要建立连接&#xff1f;3、长连接模式下需要耗费大量资源4、Kafka遇到的问题&#xff1a;应对大量客户端连接5、Kafka的架构实践&#xff1a;Reactor多路复用6、优化后的架构是如何支撑大量连接的 这篇文章&#x…

SQL Server 全文索引的应用

在公司项目中提出了一个需求&#xff1a; 搜索包含指定关键词的数据。得到这需求后&#xff0c;站在技术角度考虑第一时间就联想到使用SQL里面“like”查询语句。进一步分析需求后&#xff0c;发现“Like”查询满足不到实际的要求。 示例&#xff1a; ---------------------…

【Ajax】接口与接口测试工具PostMan

一、接口接口的概念使用 Ajax 请求数据时&#xff0c;被请求的 URL 地址&#xff0c;就叫做数据接口&#xff08;简称接口&#xff09;。同时&#xff0c;每个接口必须有请求方式。例如&#xff1a;http://www.liulongbin.top:3006/api/getbooks 获取图书列表的接口(GET请求)ht…

【4 - 降维算法PCA和SVD - 案例部分】菜菜sklearn机器学习

课程地址&#xff1a;《菜菜的机器学习sklearn课堂》_哔哩哔哩_bilibili 第一期&#xff1a;sklearn入门 & 决策树在sklearn中的实现第二期&#xff1a;随机森林在sklearn中的实现第三期&#xff1a;sklearn中的数据预处理和特征工程第四期&#xff1a;sklearn中的降维算法…

为何香港的IB状元特别多?

今年IB预科课程&#xff08;The International Baccalaureate Diploma Programme&#xff0c;IBDP&#xff09;公开考试放榜&#xff0c;香港的学校又是大丰收的一年&#xff01;因为香港今年一共有九十三名IB状元&#xff0c;即IB的总分为四十五分满分&#xff0c;而他们全部取…

Linux 环境部署 Nexus 服务

一 私服是什么&#xff1f; 一个特殊的远程仓库&#xff0c;它是架设在局域网内的仓库服务&#xff0c;供局域网内的开发人员使用。 当Maven需要下载构建的使用&#xff0c; 它先从私服请求&#xff0c;如果私服上没有的话&#xff0c;则从外部的远程仓库下载&#xff0c;然后…