pytorch-lightning中使用wandb实现超参数搜索

news2025/1/26 15:50:56

由于最近涉及下游任务微调,预训练任务中的框架使用的是pytorch-lightning,使用了典型的VLP(vision-language modeling)的训练架构,如Vilt代码中:https://github.com/dandelin/ViLT,这类架构中只涉及到预训练,但是在下游任务中微调没有出现如何调参的过程。因此可以使用wandb的sweeps来对下游任务进行超参数搜索。

    • 问题

Vilt的目录结构:

这类预训练大模型,涉及到大量的参数,这些参数均使用Sacred框架进行统一管理(放在上图中的config.py文件中),其中大部分的参数是固定的(即预训练模型固定参数),下游任务只是对学习率、batch_size、最后一层全连接层等的配置,因此我们超参数搜索只是其中一小部分,然而对wandb的超参数搜索sweeps来说,它有自带的参数管理,因此两者的参数管理会存在冲突。

2、解决问题

第一步:定义sweeps 配置,设置超参数搜索方法和搜索范围

新建一个文件夹,用于存储sweeps的配置,命名为sweeps_config.py:

# sweeps_config.py
import math

sweep_config = {
    "name": "sweep_with_launchpad", # 自定义,用于命名sweep超参数的名称
    "metric": {"name": "val/the_metric", "goal": "maximize"}, # 监控指标,name值应为wandb.log对象中出现key值,goal为maximize或者minimize

    "method": "grid", # 搜索方式,这里为网格搜索,还有"random"等设置
    "parameters": { # 搜索的范围
        "batch_size": {
            "value": 128 
        },
        "max_steps": {
            "value": 100 # 如果是值,注意为"value"
        },
        "max_epoch": {
            "value": 100
        },
        "learning_rate": {
            "values": [5e-6,1e-5,5e-5] # 如果是范围,注意为"values"
        }
    }
}

第二步:初始化sweep

在主函数main中,初始化sweep:

sweep_id = wandb.sweep(sweep_config,project='myCLIP') # project是在wandb中的项目名称

第三步:代码嵌入

由于模型已经预训练好了,模型结构基本不变,仅仅微调,因此新建finturn.py文件作为微调的运行文件,代码如下:

import os
import numpy as np
import random
import time
import datetime
import torch
import copy
from config import ex
import pytorch_lightning as pl
from datamodules.datamodules_multitask import MTDataModule
from models.myCLIP import myCLIP
from hyparam_search.sweep_config import sweep_config # 导入搜索范围
from pytorch_lightning.loggers import WandbLogger
import wandb
wandb.login()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

@ex.automain
def main(_config):
    # 初始化参数
    start_time = time.time()
    # _config是Scale管理的参数,即config.py中的参数
    _config = copy.deepcopy(_config)

    if _config['is_pretrain'] == False: # 下游微调
        
        sweep_id = wandb.sweep(sweep_config,project='myCLIP') # 这是第二步中的初始化sweep

        # wandb.init(project="")  # 此处不能init,如果init了,会报错。
    
    # 训练函数
    def train(config=None):
        # 设置种子
        pl.seed_everything(_config["seed"])

#################################(下面为重点代码)#####################################

        with wandb.init(config=None): # 初始化wandb
            # print(wandb.config)
            config = wandb.config # 如果调用了wandb.agent函数,wandb.config会对sweep_config中的参数自动更新,选择一组未被使用过的超参数。每调用一次train,超参数config会更新一次
            print(config)
            _config.update(config) # 将config中选出的训练参数更新到_config中,用于训练模型
            print(_config)
#################################(下面为正常设置)#####################################

            dm = MTDataModule(_config, dist=False)

            model = myCLIP(_config)
            exp_name = f'{_config["exp_name"]}'
            # 日志打印文件
            os.makedirs(_config["log_dir"], exist_ok=True)
            # checkpoint保存配置
            checkpoint_callback = pl.callbacks.ModelCheckpoint(
                save_top_k=1,
                verbose=True,
                monitor="val/the_metric",  # 想监视的指标
                mode=_config['mode'],
                save_last=False,
                dirpath=_config['checkpoint_save_path'],
                filename="{epoch:02d}-{global_step}-64",
            )
            wandb_logger = WandbLogger(project="myCLIP")

            # 学习率回调函数
            lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
            callbacks = [checkpoint_callback, lr_callback]
            num_gpus = (
                _config["num_gpus"]
                if isinstance(_config["num_gpus"], int)
                else len(_config["num_gpus"])
            )
            # 4096 / (4*1*1)
            grad_steps = max(_config["batch_size"] // (
                    _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
            ), 1)
            max_steps = _config["max_steps"] if _config["max_steps"] is not None else None

            trainer = pl.Trainer(
                gpus=_config["num_gpus"],  # 使用gpu列表
                num_nodes=_config["num_nodes"],  # 节点数
                precision=_config["precision"],  # 指定训练精度
                accelerator="cuda",  #
                benchmark=True,
                deterministic=not _config['is_pretrain'],  # 预训练为False,用到了gather函数。微调用True,可复现
                max_epochs=_config["max_epoch"] if max_steps is None else 200,
                max_steps=max_steps,
                callbacks=callbacks,  # 回调函数,保存checkpoint
                logger=wandb_logger,  # 打印日志
                replace_sampler_ddp=False,  #
                accumulate_grad_batches=grad_steps,  # 每k次batches累计一次梯度
                log_every_n_steps=10,  # 更新n次网络权重后记录一次日志
                resume_from_checkpoint=_config["resume_from"],  #
                fast_dev_run=_config["fast_dev_run"],
                val_check_interval=_config["val_check_interval"],
                # strategy="ddp_find_unused_parameters_false"
            )

            # 训练
            if not _config["test_only"]:
                trainer.fit(model, datamodule=dm)

    # 调用agent函数,这是第四步:运行
    wandb.agent(sweep_id, train, count=50)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

第四步:运行

启动一个运行 10 次训练的agent,使用 Sweep Controller 返回的网格生成的超参数值。

sweep_id : 是第二步中初始化时的id。

train : 是第三步中main函数中嵌入的train函数。

count :一个整数值,自定义。

wandb.agent(sweep_id, train, count=10)

执行代码sh:

python finturn.py with task_finetune

第五步:实验记录

进入wandb官网(友好上网),登录后进入project(我的是"myCLIP"),进入后点击Sweeps列表,即可看到该次运行的结果:

参考资料:

https://zhuanlan.zhihu.com/p/436385177

https://docs.wandb.ai/ref/python/agent

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/185481.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

51单片机学习-5定时器与中断

5 定时器与中断 [toc] 注:笔记主要参考B站江科大自化协教学视频“51单片机入门教程-2020版 程序全程纯手打 从零开始入门”。 注:工程及代码文件放在了本人的Github仓库。 5.1 定时器原理与中断系统 5.1.1 定时器原理 CPU的时序指标有: 振…

C语言预处理命令是什么?

C语言源文件要经过编译、链接才能生成可执行程序:1) 编译(Compile)会将源文件(.c文件)转换为目标文件。对于 VC/VS,目标文件后缀为.obj;对于GCC,目标文件后缀为.o。编译是针对单个源…

ESP32设备驱动-ADS1015(ADC)驱动

ADS1015(ADC)驱动 1、ADS1015介绍 ADS1015 是一款具有 12 位分辨率的精密模数转换器 (ADC),采用超小型无引线 QFN-10 封装或 MSOP-10 封装。 ADS1015 的设计考虑了精度、功率和易于实施。 ADS1015 具有板载基准和振荡器。 数据通过 I2C 兼容的串行接口传输; 可以选择四个 I…

Portapack应用开发教程(十八)NavTex接收 C

有段时间没研究NavTex了,这段时间打算捡起来继续搞。 上一篇文章中,我用frisnit生成了wav文件。然后再用gnuradio观察波形,发现波形确实能与frisnit中的描述以及python解码程序中的dictionary对应上。 接下来,我要重点想办法自己…

Rust机器学习之petgraph

Rust机器学习之petgraph 图作为一种重要的数据结构和表示工具在科学和技术中无处不在。因此,许多软件项目会以各种形式用到图。尤其在知识图谱和因果AI领域,图是最基础的表达和研究工具。Python有著名的NetworksX库,便于用户对复杂网络进行创…

apt命令详解

apt(Advanced Packaging Tool)是一个在 Debian 和 Ubuntu 中的 Shell 前端软件包管理器。 apt 命令提供了查找、安装、升级、删除某一个、一组甚至全部软件包的命令,而且命令简洁而又好记。 apt 命令执行需要超级管理员权限(root)。前些日子…

基于java ssm springboot宠物用品商城系统

基于java ssm springboot宠物用品商城系统 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取源码联系方式…

Python 基础语法介绍(一)

文章目录一、概述二、变量1)变量定义2)定义变量的规则3)变量命名规范4)变量类型转换三、注释1)单行注释2)多行注释1、单引号()注释2、双引号("""&#xf…

Kubernetes 体验 kubecolor

Kubernetes 体验 kubecolorkubecolor 概述Github 地址安装 kubecolor设置.bashrc使用 kubecolorkubecolor 概述 对你的kubectl输出进行着色。 kubecolor在内部调用kubectl命令并尝试对输出进行着色,因此你可以将kubecolor作为kubectl的一个完整的替代品。这意味着…

JAVA经典面试题带答案(一)

目录 1、JDK 和 JRE 有什么区别? 2、 和 equals 的区别是什么? 3、final 在 java 中有什么作用? 4、java 中的 Math.round(-1.5) 等于多少? 5、String 属于基础的数据类型吗? 不属于。 6、String str"i&quo…

51单片机学习笔记-13直流电机

13 直流电机 [toc] 注:笔记主要参考B站江科大自化协教学视频“51单片机入门教程-2020版 程序全程纯手打 从零开始入门”。 注:工程及代码文件放在了本人的Github仓库。 13.1 直流电机与PWM波 13.1.1 直流电机 直流电机是一种将电能转换为机械能的装置…

Docker -- 部署Mysql主从服务

以下是配置一主两从的Mysql服务的具体流程。 文章目录创建用于挂载的目录修改cnf配置拉取mysql服务镜像自定义docker网络启动容器主库配置查看主库状态创建从库备份用户从库配置修改Master信息启动slave服务查看slave服务状态是否正常创建用于挂载的目录 保证数据的持久化&…

Databend 内幕大揭秘第二弹 - Data Source

本篇是 minibend 系列的第二期,将会介绍 Data Source 部分的设计与实现,当然,由于是刚开始涉及到编程的部分,也会提到包括 类型系统 和 错误处理 之类的一些额外内容。 前排指路视频和 PPT 地址 视频(哔哩哔哩&#xf…

23种设计模式之趣味学习篇

23种设计模式之趣味学习篇1. 设计模式概述1.1 什么是设计模式1.2 设计模式的好处2. 设计原则分类3. 详解3.1 单一职责原则3.2 开闭原则3.3 里氏代换原则3.4 依赖倒转原则3.5 接口隔离原则3.6 合成复用原则3.7 迪米特法则4. Awakening1. 设计模式概述 我们的软件开发技术也包括一…

【1669. 合并两个链表】

来源:力扣(LeetCode) 描述: 给你两个链表 list1 和 list2 ,它们包含的元素分别为 n 个和 m 个。 请你将 list1 中下标从 a 到 b 的全部节点都删除,并将list2 接在被删除节点的位置。 下图中蓝色边和节点…

【算法竞赛学习】csoj:寒假第二场

文章目录前言红包接龙最后一班勇者兔兔兔爱消除吃席兔知识拓展std::greater | 堆优化参考iota函数参考并查集参考sort自定义函数参考树形dp参考使用auto时控制分隔符前言 由于本人菜鸡,所以大多都是使用出题人的代码和思路 如有侵权,麻烦联系up删帖&…

pytorch_sparse教程

pytorch_sparse教程 Coalesce torch_sparse.coalesce(index, value, m, n, op"add") -> (torch.LongTensor, torch.Tensor) 逐行排序index并删除重复项。通过将重复项映射到一起来删除重复项。对于映射,可以使用任何一种torch_scatter操作。 参数 i…

来回修改的投标文件怎么做版本管理?1个工具搞定!

投标是公司市场活动中非常重要的事情,每次投标文件的编写像打仗一样,要修改很多次,不保存每个版本就只能在需要的时候后悔,多个文件、多人编写、多种方案要再最后的几个小时才能定,每次都是弄得鸡飞狗跳的,…

Python卷积神经网络CNN

Python卷积神经网络CNN 提示:前言 Python卷积神经网络CNN 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录Python卷积神经网络CNN前言一、导入包二、介绍三、卷积过滤四、权重五、展示特征图六、用 ReLU…

一文快速入门哈希表

目录一、基本概念1.1 哈希冲突二、整数哈希2.1 哈希函数的设计2.2 解决哈希冲突2.2.1 开放寻址法2.2.2 拉链法三、字符串哈希3.1 应用:重复的DNA序列References一、基本概念 哈希表又称散列表,一种以「key-value」形式存储数据的数据结构。所谓以「key-…