MMaction2 使用记录1——config介绍

news2025/1/11 5:57:09

 

目录

了解config (模型训练测试的整体过程配置文件)

通过脚本参数修改config

Config 文件 结构

config文件的命名规则

动作识别的config系统 


了解config (模型训练测试的整体过程配置文件)

我们使用python文件作为config,将模块化和继承性设计融入到我们的config系统中,方便进行各种实验。可以在MMAction2/configs下找到所有提供的config。如果你想检查、分析config文件,你可以运行python tools/analysis_tools/print_config.py /PATH/TO/CONFIG来查看完整的config。

通过脚本参数修改config

当使用tools/train.py或tools/test.py提交jobs时,你可以指定--cfg-options来就地修改config。

更新config的字典关键字。

config选项可以按照原始配置中dict的keys顺序来指定。例如,--cfg-options model.backbone.norm_eval=False将模型骨干中的所有BN模块改为训练模式。

更新configs list内的keys。

一些config字典在你的config中被组成一个list。例如,训练pipeline train_pipeline 通常是一个list,例如 [dict(type='SampleFrames'), ...]。如果你想把pipeline中的 "SampleFrames "改为 "DenseSampleFrames",你可以指定--cfg-options train_pipeline.0.type=DenseSampleFrames。

更新list/tuples的值。

如果要更新的值是一个列表或一个元组。例如,config文件通常设置model.data_preprocessor.mean=[123.675, 116.28, 103.53]。如果你想改变这个键,你可以指定--cfg-options model.data_preprocessor.mean="[128,128,128]"。注意,引号 "是支持列表/元组数据类型。

Config 文件 结构

configs/_base_下有3种基本的组件类型,models、schedules、default_runtime。许多方法可以很容易地用其中的一个来构建,比如TSN、I3D、SlowOnly等等。由来自_base_的组件组成的配置被称为primitive。

对于同一文件夹下的所有configs,建议只拥有一个primitive的config。所有其他的config都应该继承自primitive的config。最大的继承级别是3。

为了便于理解,我们建议贡献者从退出的方法中继承。例如,如果在TSN的基础上做了一些修改,用户可以首先通过指定_base_ = ../tsn/tsn_imagenet-retrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py来继承基本的TSN结构,然后修改配置文件中的必要字段。

如果你正在构建一个全新的方法,并且不与任何现有的方法共享结构,你可以在configs/TASK下创建一个文件夹。

这部分具体请参考mmengine的详细文档。

config文件的命名规则

我们遵循下面的风格来命名config文件。建议贡献者遵循同样的风格。config文件的名称被分为几个部分。从逻辑上讲,不同的部分用下划线"_"连接,而同一部分的设置用破折号"-"连接。

{algorithm info}_{module info}_{training info}_{data info}.py

动作识别的config系统 

我们将模块化设计纳入我们的config系统,这便于进行各种实验。

一个TSN模型的例子

为了帮助用户对一个完整的config结构和动作识别系统中的模块有一个基本的概念,我们对TSN的config做了如下的简要介绍。关于每个模块中每个参数的更详细的用法和选择,请参考API文档。

# model setting
model = dict(  # Config of the model
    type='FastRCNN',  # Class name of the detector
    _scope_='mmdet',  # The scope of current config
    backbone=dict(  # Dict for backbone
        type='ResNet3dSlowOnly',  # Name of the backbone
        depth=50, # Depth of ResNet model
        pretrained=None,   # The url/site of the pretrained model
        pretrained2d=False, # If the pretrained model is 2D
        lateral=False,  # If the backbone is with lateral connections
        num_stages=4, # Stages of ResNet model
        conv1_kernel=(1, 7, 7), # Conv1 kernel size
        conv1_stride_t=1, # Conv1 temporal stride
        pool1_stride_t=1, # Pool1 temporal stride
        spatial_strides=(1, 2, 2, 1)),  # The spatial stride for each ResNet stage
    roi_head=dict(  # Dict for roi_head
        type='AVARoIHead',  # Name of the roi_head
        bbox_roi_extractor=dict(  # Dict for bbox_roi_extractor
            type='SingleRoIExtractor3D',  # Name of the bbox_roi_extractor
            roi_layer_type='RoIAlign',  # Type of the RoI op
            output_size=8,  # Output feature size of the RoI op
            with_temporal_pool=True), # If temporal dim is pooled
        bbox_head=dict( # Dict for bbox_head
            type='BBoxHeadAVA', # Name of the bbox_head
            in_channels=2048, # Number of channels of the input feature
            num_classes=81, # Number of action classes + 1
            multilabel=True,  # If the dataset is multilabel
            dropout_ratio=0.5),  # The dropout ratio used
    data_preprocessor=dict(  # Dict for data preprocessor
        type='ActionDataPreprocessor',  # Name of data preprocessor
        mean=[123.675, 116.28, 103.53],  # Mean values of different channels to normalize
        std=[58.395, 57.12, 57.375],  # Std values of different channels to normalize
        format_shape='NCHW')),  # Final image shape format
    # model training and testing settings
    train_cfg=dict(  # Training config of FastRCNN
        rcnn=dict(  # Dict for rcnn training config
            assigner=dict(  # Dict for assigner
                type='MaxIoUAssignerAVA', # Name of the assigner
                pos_iou_thr=0.9,  # IoU threshold for positive examples, > pos_iou_thr -> positive
                neg_iou_thr=0.9,  # IoU threshold for negative examples, < neg_iou_thr -> negative
                min_pos_iou=0.9), # Minimum acceptable IoU for positive examples
            sampler=dict( # Dict for sample
                type='RandomSampler', # Name of the sampler
                num=32, # Batch Size of the sampler
                pos_fraction=1, # Positive bbox fraction of the sampler
                neg_pos_ub=-1,  # Upper bound of the ratio of num negative to num positive
                add_gt_as_proposals=True), # Add gt bboxes as proposals
            pos_weight=1.0)),  # Loss weight of positive examples
    test_cfg=dict(rcnn=None))  # Testing config of FastRCNN

# dataset settings
dataset_type = 'AVADataset' # Type of dataset for training, validation and testing
data_root = 'data/ava/rawframes'  # Root path to data
anno_root = 'data/ava/annotations'  # Root path to annotations

ann_file_train = f'{anno_root}/ava_train_v2.1.csv'  # Path to the annotation file for training
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'  # Path to the annotation file for validation

exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'  # Path to the exclude annotation file for training
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'  # Path to the exclude annotation file for validation

label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'  # Path to the label file

proposal_file_train = f'{anno_root}/ava_dense_proposals_train.FAIR.recall_93.9.pkl'  # Path to the human detection proposals for training examples
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'  # Path to the human detection proposals for validation examples

train_pipeline = [  # Training data processing pipeline
    dict(  # Config of SampleFrames
        type='AVASampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=4,  # Frames of each sampled output clip
        frame_interval=16),  # Temporal interval of adjacent sampled frames
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of RandomRescale
        type='RandomRescale',   # Randomly rescale the shortedge by a given range
        scale_range=(256, 320)),   # The shortedge size range of RandomRescale
    dict(  # Config of RandomCrop
        type='RandomCrop',   # Randomly crop a patch with the given size
        size=256),   # The size of the cropped patch
    dict(  # Config of Flip
        type='Flip',  # Flip Pipeline
        flip_ratio=0.5),  # Probability of implementing flip
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCTHW',  # Final image shape format
        collapse=True),   # Collapse the dim N if N == 1
    dict(type='PackActionInputs') # Pack input data
]

val_pipeline = [  # Validation data processing pipeline
    dict(  # Config of SampleFrames
        type='AVASampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=4,  # Frames of each sampled output clip
        frame_interval=16),  # Temporal interval of adjacent sampled frames
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(-1, 256)),  # The scale to resize images
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCTHW',  # Final image shape format
        collapse=True),   # Collapse the dim N if N == 1
    dict(type='PackActionInputs') # Pack input data
]

train_dataloader = dict(  # Config of train dataloader
    batch_size=32,  # Batch size of each single GPU during training
    num_workers=8,  # Workers to pre-fetch data for each single GPU during training
    persistent_workers=True,  # If `True`, the dataloader will not shut down the worker processes after an epoch end, which can accelerate training speed
    sampler=dict(
        type='DefaultSampler',  # DefaultSampler which supports both distributed and non-distributed training. Refer to https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/sampler.py
        shuffle=True),  # Randomly shuffle the training data in each epoch
    dataset=dict(  # Config of train dataset
        type=dataset_type,
        ann_file=ann_file_train,  # Path of annotation file
        exclude_file=exclude_file_train,  # Path of exclude annotation file
        label_file=label_file,  # Path of label file
        data_prefix=dict(img=data_root),  # Prefix of frame path
        proposal_file=proposal_file_train,  # Path of human detection proposals
        pipeline=train_pipeline))
val_dataloader = dict(  # Config of validation dataloader
    batch_size=1,  # Batch size of each single GPU during evaluation
    num_workers=8,  # Workers to pre-fetch data for each single GPU during evaluation
    persistent_workers=True,  # If `True`, the dataloader will not shut down the worker processes after an epoch end
    sampler=dict(
        type='DefaultSampler',
        shuffle=False),  # Not shuffle during validation and testing
    dataset=dict(  # Config of validation dataset
        type=dataset_type,
        ann_file=ann_file_val,  # Path of annotation file
        exclude_file=exclude_file_val,  # Path of exclude annotation file
        label_file=label_file,  # Path of label file
        data_prefix=dict(img=data_root_val),  # Prefix of frame path
        proposal_file=proposal_file_val,  # Path of human detection proposals
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = val_dataloader  # Config of testing dataloader

# evaluation settings
val_evaluator = dict(  # Config of validation evaluator
    type='AVAMetric',
    ann_file=ann_file_val,
    label_file=label_file,
    exclude_file=exclude_file_val)
test_evaluator = val_evaluator  # Config of testing evaluator

train_cfg = dict(  # Config of training loop
    type='EpochBasedTrainLoop',  # Name of training loop
    max_epochs=20,  # Total training epochs
    val_begin=1,  # The epoch that begins validating
    val_interval=1)  # Validation interval
val_cfg = dict(  # Config of validation loop
    type='ValLoop')  # Name of validation loop
test_cfg = dict( # Config of testing loop
    type='TestLoop')  # Name of testing loop

# learning policy
param_scheduler = [ # Parameter scheduler for updating optimizer parameters, support dict or list
    dict(type='LinearLR',  # Decays the learning rate of each parameter group by linearly changing small multiplicative factor
        start_factor=0.1,  # The number we multiply learning rate in the first epoch
        by_epoch=True,  # Whether the scheduled learning rate is updated by epochs
  	  begin=0,  # Step at which to start updating the learning rate
  	  end=5),  # Step at which to stop updating the learning rate
    dict(type='MultiStepLR',  # Decays the learning rate once the number of epoch reaches one of the milestones
        begin=0,  # Step at which to start updating the learning rate
        end=20,  # Step at which to stop updating the learning rate
        by_epoch=True,  # Whether the scheduled learning rate is updated by epochs
        milestones=[10, 15],  # Steps to decay the learning rate
        gamma=0.1)]  # Multiplicative factor of learning rate decay

# optimizer
optim_wrapper = dict(  # Config of optimizer wrapper
    type='OptimWrapper',  # Name of optimizer wrapper, switch to AmpOptimWrapper to enable mixed precision training
    optimizer=dict(  # Config of optimizer. Support all kinds of optimizers in PyTorch. Refer to https://pytorch.org/docs/stable/optim.html#algorithms
        type='SGD',  # Name of optimizer
        lr=0.2,  # Learning rate
        momentum=0.9,  # Momentum factor
        weight_decay=0.0001),  # Weight decay
    clip_grad=dict(max_norm=40, norm_type=2))  # Config of gradient clip

# runtime settings
default_scope = 'mmaction'  # The default registry scope to find modules. Refer to https://mmengine.readthedocs.io/en/latest/tutorials/registry.html
default_hooks = dict(  # Hooks to execute default actions like updating model parameters and saving checkpoints.
    runtime_info=dict(type='RuntimeInfoHook'),  # The hook to updates runtime information into message hub
    timer=dict(type='IterTimerHook'),  # The logger used to record time spent during iteration
    logger=dict(
        type='LoggerHook',  # The logger used to record logs during training/validation/testing phase
        interval=20,  # Interval to print the log
        ignore_last=False), # Ignore the log of last iterations in each epoch
    param_scheduler=dict(type='ParamSchedulerHook'),  # The hook to update some hyper-parameters in optimizer
    checkpoint=dict(
        type='CheckpointHook',  # The hook to save checkpoints periodically
        interval=3,  # The saving period
        save_best='auto',  # Specified metric to mearsure the best checkpoint during evaluation
        max_keep_ckpts=3),  # The maximum checkpoints to keep
    sampler_seed=dict(type='DistSamplerSeedHook'),  # Data-loading sampler for distributed training
    sync_buffers=dict(type='SyncBuffersHook'))  # Synchronize model buffers at the end of each epoch
env_cfg = dict(  # Dict for setting environment
    cudnn_benchmark=False,  # Whether to enable cudnn benchmark
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), # Parameters to setup multiprocessing
    dist_cfg=dict(backend='nccl')) # Parameters to setup distributed environment, the port can also be set

log_processor = dict(
    type='LogProcessor',  # Log processor used to format log information
    window_size=20,  # Default smooth interval
    by_epoch=True)  # Whether to format logs with epoch type
vis_backends = [  # List of visualization backends
    dict(type='LocalVisBackend')]  # Local visualization backend
visualizer = dict(  # Config of visualizer
    type='ActionVisualizer',  # Name of visualizer
    vis_backends=vis_backends)
log_level = 'INFO'  # The level of logging
load_from = ('https://download.openmmlab.com/mmaction/v1.0/recognition/slowonly/'
             'slowonly_imagenet-pretrained-r50_8xb16-4x16x1-steplr-150e_kinetics400-rgb/'
             'slowonly_imagenet-pretrained-r50_8xb16-4x16x1-steplr-150e_kinetics400-rgb_20220901-e7b65fad.pth')  # Load model checkpoint as a pre-trained model from a given path. This will not resume training.
resume = False  # Whether to resume from the checkpoint defined in `load_from`. If `load_from` is None, it will resume the latest checkpoint in the `work_dir`.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/717151.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FreeRTOS学习笔记—任务挂起和恢复

文章目录 一、任务挂起和恢复API函数1.1 vTaskSuspend()函数1.2 vTaskResume()函数1.3 xTaskResumeFromISR()函数 二、任务挂起和恢复2.1 任务1挂起解挂任务22.2 中断中解挂任务1 三、补充内容3.1 FreeRTOS数据类型3.2 中断优先级分组3.3 错误问题 一、任务挂起和恢复API函数 …

ChatGPT | Word文档如何更好地提取表格内容给ChatGPT

本文来自http://blog.csdn.net/hellogv/ &#xff0c;引用必须注明出处&#xff01; Word文档如何更好地提取表格内容给ChatGPT做知识库&#xff0c;这属于文本预处理工作。 本文只讲思路、测试结果&#xff0c;技术实现用Python和Java都能完成&#xff0c;下一篇文章再贴源码…

Python实用工具--全python制作一个音乐下载器

前言 又来展示一下关于Python的实用小技巧了&#xff0c;这次就来分享分享–如何用Python来制作一个音乐下载器 做这个有什么用啊&#xff0c;我只能说&#xff0c;可以免费下载歌曲啊&#xff0c;这样就能每月保住自己钱包咯 效果展示 基本界面 图片以及文字都是可以自己更…

《动手学深度学习》——线性神经网络

参考资料&#xff1a; 《动手学深度学习》 3.1 线性回归 3.1.1 线性回归的基本元素 样本&#xff1a; n n n 表示样本数&#xff0c; x ( i ) [ x 1 ( i ) , x 2 ( i ) , ⋯ , x d ( i ) ] x^{(i)}[x^{(i)}_1,x^{(i)}_2,\cdots,x^{(i)}_d] x(i)[x1(i)​,x2(i)​,⋯,xd(i)​…

序列化对象

1&#xff1a;对象序列化 以内存为基准&#xff0c;把内存中的对象存储到磁盘文件中去&#xff0c;称为对象序列化。使用到的流是对象字节输出流&#xff1a;ObjectOutputStream 2&#xff1a;对象要序列化&#xff0c;必须实现Serializable序列化接口 2&#xff1a;对象反序…

二十四、HTTPS

文章目录 一、HTTPS&#xff08;一&#xff09;定义&#xff08;二&#xff09;HTTP与HTTPS1.端口不同&#xff0c;是两套服务2.HTTP效率更高&#xff0c;HTTPS更安全 &#xff08;三&#xff09;加密&#xff0c;解密&#xff0c;密钥等概念&#xff08;四&#xff09;为什么要…

【H5】文件下载(javascript)

系列文章 【移动设备】iData 50P 技术规格 本文链接&#xff1a;https://blog.csdn.net/youcheng_ge/article/details/130604517 【H5】avalon前端数据双向绑定 本文链接&#xff1a;https://blog.csdn.net/youcheng_ge/article/details/131067187 【H5】安卓自动更新方案&a…

hivesql列转行

原表&#xff1a; 目标表&#xff1a; sql代码&#xff1a; select dp as 日期 ,city_name as 城市, split_part(subject,‘:’,1) as 指标, cast( split_part(subject,‘:’,2) as double ) as 数值 from( select trans_array(2,‘;’,dp,city_name,subject) as (dp,city_na…

探秘高逼格艺术二维码的制作过程-AI绘画文生图

前几天看到几个逼格比较高的二维码&#xff0c;然后自己动手做了一下&#xff0c;给大家看看效果&#xff1a; 1、文生图&#xff08;狮子&#xff09;&#xff1a; 2、文生图&#xff08;城市&#xff09;&#xff1a; 下边将开始介绍怎么做的&#xff0c;有兴趣的可以继续读…

Vault AppRole最佳实现过程

AppRole AppRole身份验证方法允许机器或应用程序使用 Vault 定义的角色进行身份验证。AppRole 的开放式设计支持使用不同的工作流和配置来应对大量应用程序。这种身份验证方法主要是面向自动化工作流程(机器和服务)设计的,对人类操作者不太有用。 “AppRole”代表一组 Vau…

大数据Doris(五十六):RESOTRE数据恢复

文章目录 RESOTRE数据恢复 一、RESTORE数据恢复原理 二、RESTORE 数据恢复语法 三、RESOTRE数据恢复案例 1、在 Doris 集群中创建 mydb_recover 库 2、执行如下命令恢复数据 3、查看 restore 作业的执行情况 四、注意事项 RESOTRE数据恢复 Doris 支持BACKUP方式将当前…

力扣 40. 组合总和 II

题目来源&#xff1a;https://leetcode.cn/problems/combination-sum-ii/description/ C题解&#xff1a; 这道题的难点在于解集中不能包含重复的组合。如果用set去重会造成超时&#xff0c;所以只能在单层递归逻辑中处理。通过识别下一个数与当前数是否相同&#xff0c;来修改…

抖音小程序--开启沙盒模式后一直报,获取白名单失败:您没有权限访问此应用

一. 出现问题 按照抖音开发文档创建沙盒环境&#xff0c;然后替换appid后一直报无权限&#xff0c;如下图&#xff1a; 最后才发现&#xff0c;登录抖音开发工具的账户必须是超级管理员账户&#xff0c;添加的协助开发者&#xff0c;就算给了全部权限&#xff0c;也依然会报上面…

Navicat 入选中国信通院发布的《中国数据库产业图谱(2023)》

7 月 4 日&#xff0c;2023 年可信数据库发展大会主论坛在北京国际会议中心成功召开。会上&#xff0c;中国信息通信研究院正式发布《中国数据库产业图谱&#xff08;2023&#xff09;》。作为中国数据库生态工具供应商&#xff0c;凭借易用、稳定、可靠的产品力&#xff0c;以…

【C++】4.工具:读取yaml配置信息

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍读取yaml配置信息。 学其所用&#xff0c;用其所学。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c;下次更新不迷路&…

web学习1--maven--项目管理工具

写在前面&#xff1a; 这学期搞主攻算法去了&#xff0c;web的知识都快忘了。开始复习学习了。 文章目录 maven介绍功能介绍maven安装jar包搜索仓库 pom文件项目介绍父工程依赖管理属性控制可选依赖构建 依赖管理依赖的传递排除依赖可选依赖 maven生命周期分模块开发模块聚合…

产品的帮助中心怎么建设?关于帮助文档的7个小技巧

用户使用产品的过程中&#xff0c;常常会遇到与产品使用相关的问题。这时候&#xff0c;用户通常会面临三个选择&#xff1a;1.寻找客服的帮助 2.阅读产品帮助文档 3.放弃使用产品。 显然&#xff0c;对于企业而言&#xff0c;当然是希望能够帮助用户解决问题&#xff0c;使其…

shiro入门

1、概述 Apache Shiro 是一个功能强大且易于使用的 Java 安全(权限)框架。借助 Shiro 您可以快速轻松地保护任何应用程序一一从最小的移动应用程序到最大的 Web 和企业应用程序。 作用&#xff1a;Shiro可以帮我们完成 &#xff1a;认证、授权、加密、会话管理、与 Web 集成、…

车载通信,来看看SOA架构通信如何跨系统的

SOA车载跨系统通信 在车载系统中实现跨系统通信时&#xff0c;SOA架构&#xff08;Service-Oriented Architecture&#xff0c;面向服务的架构&#xff09;可以提供一种有效的解决方案。以下是一种基于SOA的车载跨系统通信的概述&#xff1a; 定义服务接口&#xff1a;首先&a…

2023年无线蓝牙耳机排行榜,十款无线蓝牙耳机品牌推荐

蓝牙耳机作为现代生活必备的电子产品之一&#xff0c;我们在选购时的选择就显得尤为重要。随着各大科技公司对蓝牙耳机功能的不断完善&#xff0c;用户对于耳机的期望也越来越高&#xff0c;音质、性能、降噪、舒适度等方面都成为了用户选择蓝牙耳机时考虑的因素。接下来我们一…