huggingface的self.state与self.control来源(TrainerState与TrainerControl)

news2024/7/4 6:36:01

文章目录

  • 前言
  • 一、huggingface的trainer的self.state与self.control初始化调用
  • 二、TrainerState源码解读(self.state)
    • 1、huggingface中self.state初始化参数
    • 2、TrainerState类的Demo
  • 三、TrainerControl源码解读(self.control)
  • 总结


前言

在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。


一、huggingface的trainer的self.state与self.control初始化调用

trainer函数初始化调用代码如下:

# 定义Trainer对象
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
   
)

在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:

class Trainer:
	def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
		...
		 self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
		...

二、TrainerState源码解读(self.state)

1、huggingface中self.state初始化参数

这里多解读一点huggingface的self.state初始化调用参数方法,

 self.state = TrainerState(
      is_local_process_zero=self.is_local_process_zero(),
      is_world_process_zero=self.is_world_process_zero(),
  )

而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。

 def is_local_process_zero(self) -> bool:
     """
     Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
     machines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。
     """
     return self.args.local_process_index == 0

 def is_world_process_zero(self) -> bool:
     """
     Whether or not this process is the global main process (when training in a distributed fashion on several
     machines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。
     """
     # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
     # process index.
     if is_sagemaker_mp_enabled():
         return smp.rank() == 0
     else:
         return self.args.process_index == 0

self.args.local_process_index与self.args.process_index来源self.args

2、TrainerState类的Demo

介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0等内容,也进行了默认参数赋值,其Demo如下:

from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:
    epoch: Optional[float] = None
    global_step: int = 0
    max_steps: int = 0
    num_train_epochs: int = 0
    total_flos: float = 0
    log_history: List[Dict[str, float]] = None
    best_metric: Optional[float] = None
    best_model_checkpoint: Optional[str] = None
    is_local_process_zero: bool = True
    is_world_process_zero: bool = True
    is_hyper_param_search: bool = False
    trial_name: str = None
    trial_params: Dict[str, Union[str, float, int, bool]] = None

    def __post_init__(self):
        if self.log_history is None:
            self.log_history = []

    def save_to_json(self, json_path: str):
        """Save the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Create an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))

if __name__ == '__main__':

    state = TrainerState()
    state.save_to_json('state.json')
    state_new = state.load_from_json('state.json')

我这里使用state = TrainerState()方法对TrainerState()类实例化,使用state.save_to_json('state.json')进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')方式载入会得到新的state_new实例化。
在这里插入图片描述

三、TrainerControl源码解读(self.control)

该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False内容,也进行了默认参数赋值,其源码如下:

@dataclass
class TrainerControl:
    """
    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
    switches in the training loop.

    Args:
        should_training_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the training should be interrupted.

            If `True`, this variable will not be set back to `False`. The training will just stop.
        should_epoch_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the current epoch should be interrupted.

            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
        should_save (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be saved at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_evaluate (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be evaluated at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_log (`bool`, *optional*, defaults to `False`):
            Whether or not the logs should be reported at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
    """

    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False

    def _new_training(self):
        """Internal method that resets the variable for a new training."""
        self.should_training_stop = False

    def _new_epoch(self):
        """Internal method that resets the variable for a new epoch."""
        self.should_epoch_stop = False

    def _new_step(self):
        """Internal method that resets the variable for a new step."""
        self.should_save = False
        self.should_evaluate = False
        self.should_log = False

总结

本文主要介绍huggingface的trainer中的self.control与self.state的来源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1715826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣:226. 翻转二叉树

226. 翻转二叉树 已解答 简单 相关标签 相关企业 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1]示例 2: 输入&#xff1a…

数组-检查数组内是否存在和为7的倍数的子序列

一、题目描述 二、解题思路 这里首先要分辨清楚是子序列还是子数组 原数组:[1,2,3,4,5] 子序列:元素和元素之间相对位置保持不变,但是在原数组中不一定连续,如:[1,3,4]; 子数组:元素元素之间保…

canfd与can2.0关系

canfd是can2.0的升级版, 支持canfd的设备就支持can2.0,但can2.0的设备不支持canfd 参考 是选CAN接口卡还是CANFD接口卡_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1Hh411K7Zn/?spm_id_from333.999.0.0 哪些STM32有CANFD外设 STM32G0, STM…

一款免费的软件媒体系统软件!!【送源码】

Jellyfin是一个免费的软件媒体系统,让您在管理和流媒体控制您的媒体。它是专有的Emby和Plex的替代品,通过多个应用程序从专用服务器向最终用户设备提供媒体。Jellvfin是Emby的3.5.2版本的后裔,并被移植到.NETCore框架中,以实现完全…

新火种AI|寻求合作伙伴,展开豪赌,推出神秘AI项目...苹果能否突破AI困境?

作者:小岩 编辑:彩云 2024年,伴随着AI技术的多次爆火,不仅各大科技巨头纷纷进入AI赛道展开角力,诸多智能手机厂商也纷纷加紧布局相关技术,推出众多AI手机。作为手机领域的龙头老大,苹果自然是…

基于单片机的步进电机控制系统研究

摘 要 : 近年来 , 步进电机凭借其定位精度高 、 使用方便 、 性价比高 、 容易控制等优点 , 在各领域受到广泛应用 。 文中利用C52 单片机设计了一种步进电机控制系统 , 介绍了其总体方案 、 主控制模块 、 驱动电路 、 键盘 、 晶…

洗地机哪个牌子最好用?十大名牌洗地机排行榜

作为一种新兴的智能家居产品,洗地机的市场规模已经突破了百亿大关。如此庞大的市场自然吸引了大量资本的涌入,许多品牌纷纷推出自己的洗地机产品,试图在这个竞争激烈的市场中占据一席之地。然而,面对如此多的品牌和型号&#xff0…

SelfKG论文翻译

SelfKG: Self-Supervised Entity Alignment in Knowledge Graphs SelfKG:知识图中的自监督实体对齐 ABSTRACT 实体对齐旨在识别不同知识图谱(KG)中的等效实体,是构建网络规模知识图谱的基本问题。在其发展过程中,标…

Java面试题分享-敏感词替换 java 版本

入职啦最近更新了一些后端笔试、面试题目,大家看看能快速实现吗? 关注 入职啦 微信公众号,每日更新有用的知识,Python,Java,Golang,Rust,javascript 等语言都有 不要再用replaceAll做…

Django之文件上传(一)

一、环境搭建 建立项目 django-admin startproject project_demo配置数据库(以MySQL为例) # settings.py DATABASES = {default: {ENGINE: django.db.backends.mysql,NAME: django_file4,USER: root,PASSWORD: 123,HOST: 192.168.31.151,PORT: 3306,} }建立模型 class UploadF…

Vue 3 教程:核心知识

Vue 3 教程:核心知识 1. Vue3简介1.1. 【性能的提升】1.2.【 源码的升级】1.3. 【拥抱TypeScript】1.4. 【新的特性】 2. 创建Vue3工程2.1. 【基于 vue-cli 创建】2.2. 【基于 vite 创建】(推荐)2.3. 【一个简单的效果】 3. Vue3核心语法3.1. 【OptionsAPI 与 Compo…

Codeforces Round 948 (Div. 2) E. Tensor(思维题-交互)

题目 n(3<n<100)个点的有向图&#xff0c; 图的边的关系未知&#xff0c;但保证以下两点&#xff1a; 1. 只存在j->i&#xff08;i<j&#xff09;的边 2. 对于任意三个点i、j、k&#xff08;i<j<k&#xff09;&#xff0c;要么k可以到达i&#xff0c;要么…

基于java实现图片中任意封闭区域识别

需求&#xff1a; 在浏览器中给用户呈现一张图片&#xff0c;用户点击图片中的某些标志物&#xff0c;需要系统给出标志物的信息反馈&#xff0c;达到一个交互的作用。 比如下图中&#xff0c;点击某个封闭区域时候&#xff0c;需要告知用户点击的区域名称及图形形状特性等等。…

Django之rest_framework(九)

一、分页-PageNumberPagination类 REST framework提供了分页的支持 官网:Pagination - Django REST framework 1.1、全局设置 # settings.py REST_FRAMEWORK = {DEFAULT_PAGINATION_CLASS: rest_framework.pagination.PageNumberPagination,PAGE_SIZE: 100 # 每页数目 }提示…

相对论表明速度越快时间越慢,为什么速度会影响时间?

在物理学的宏伟殿堂中&#xff0c;相对论以其深邃的洞察力&#xff0c;挑战了我们对时间和空间的传统认识。1905年&#xff0c;阿尔伯特爱因斯坦提出了狭义相对论&#xff0c;揭示了在所有惯性参照系中&#xff0c;光速是常数的惊人事实。 随后在1915年&#xff0c;他进一步发展…

ABAP 在增强中COMMIT

前言 呃&#xff0c;又是很磨人的需求&#xff0c;正常情况下是不允许在增强中COMMIT的&#xff0c;会影响源程序本身的逻辑&#xff0c;但是这个需求就得这么干… 就是在交货单增强里面要再调用一次交货单BAPI&#xff0c;通过SO的交货单自动创建STO的交货单&#xff0c;如果…

uniapp h5项目切换导航栏及动态渲染按钮颜色

1.效果图 2.html,动态渲染按钮样式---三元判断 <!-- 切换栏 --><view class"statusList"><block v-for"(item,index) in list" :key"index"><view class"swiper-tab-list" :class"current item.id?activ…

基于扩散模型的,开源世界模型DIAMOND

日内瓦大学、微软研究院和爱丁堡大学的研究人员联合开源了&#xff0c;基于扩散模型的世界模型—DIAMOND。 研究人员之所以选择扩散模型作为基础&#xff0c;是因为可以更好地捕捉视觉细节&#xff0c;同时具有建模复杂多模态分布的能力&#xff0c;以便在不同的环境下进行训练…

网络融合的力量:企业如何通过“一网多用”提升业务效率

随着企业业务的不断扩展&#xff0c;网络需求变得日益复杂。需要的是一种能够统一承载办公、生产、销售和运营等多业务需求的网络架构。这种“一网多用”的架构&#xff0c;不仅简化了网络部署和管理&#xff0c;还提升了效率并降低了成本。 “一网多用”架构的实际应用&#x…

[ C++ ] 深入理解模板( 初 阶 )

函数模板 函数模板格式 template <typename T1, typename T2,......,typename Tn> 返回值类型 函数名(参数列表){} 注意&#xff1a; typename是用来定义模板参数关键字&#xff0c;也可以使用class(切记&#xff1a;不能使用struct代替class) 函数模板的实例化 模板参数…