30分钟吃掉wandb可视化自动调参

news2024/12/28 22:32:09

wandb.sweep: 低代码,可视化,分布式 自动调参工具。

使用wandb 的 sweep 进行超参调优,具有以下优点。

(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。

(2)可视化:在wandb网页中可以实时监控调参过程中每次尝试,并可视化地分析调参任务的目标值分布,超参重要性等。

(3)分布式:sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

公众号后台回复关键词:wandb,获取本文notebook代码和B站视频演示。

使用 wandb 的sweep 调参的缺点:

需要联网:由于wandb的controller位于wandb的服务器机器上,wandb日志也需要联网上传,在没有互联网的环境下无法正常使用wandb 进行模型跟踪 以及 wandb sweep 可视化调参。

d6731f3afe349a385fa50a5eb394b50a.png

〇,使用Sweep的3步骤

  1. 配置 sweep_config

配置调优算法,调优目标,需要优化的超参数列表 等等。
  1. 初始化 sweep controller:

sweep_id = wandb.sweep(sweep_config,project)
  1. 启动 sweep agents:

wandb.agent(sweep_id, function=train)
import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb 

wandb.login()
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#初始化参数配置
config = Namespace(
    project_name = 'wandb_demo',
    
    batch_size = 512,
    
    hidden_layer_width = 64,
    dropout_p = 0.1,
    
    lr = 1e-4,
    optim_type = 'Adam',
    
    epochs = 15,
    ckpt_path = 'checkpoint.pt'
)

一. 配置 Sweep config

详细配置文档可以参考:https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

1,选择一个调优算法

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes. 创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

sweep_config = {
    'method': 'random'
    }

2,定义调优目标

设置优化指标,以及优化方向。

sweep agents 通过 wandb.log 的形式向 sweep controller 传递优化目标的值。

metric = {
    'name': 'val_acc',
    'goal': 'maximize'   
    }
sweep_config['metric'] = metric

3,定义超参空间

超参空间可以分成 固定型,离散型和连续型。

  • 固定型:指定 value

  • 离散型:指定 values,列出全部候选取值。

  • 连续性:需要指定 分布类型 distribution, 和范围 min, max。用于 random 或者 bayes采样。

sweep_config['parameters'] = {}

# 固定不变的超参
sweep_config['parameters'].update({
    'project_name':{'value':'wandb_demo'},
    'epochs': {'value': 10},
    'ckpt_path': {'value':'checkpoint.pt'}})

# 离散型分布超参
sweep_config['parameters'].update({
    'optim_type': {
        'values': ['Adam', 'SGD','AdamW']
        },
    'hidden_layer_width': {
        'values': [16,32,48,64,80,96,112,128]
        }
    })

# 连续型分布超参
sweep_config['parameters'].update({
    
    'lr': {
        'distribution': 'log_uniform_values',
        'min': 1e-6,
        'max': 0.1
      },
    
    'batch_size': {
        'distribution': 'q_uniform',
        'q': 8,
        'min': 32,
        'max': 256,
      },
    
    'dropout_p': {
        'distribution': 'uniform',
        'min': 0,
        'max': 0.6,
      }
})

4,定义剪枝策略 (可选)

可以定义剪枝策略,提前终止那些没有希望的任务。

sweep_config['early_terminate'] = {
    'type':'hyperband',
    'min_iter':3,
    'eta':2,
    's':3
} #在step=3, 6, 12 时考虑是否剪枝
from pprint import pprint
pprint(sweep_config)

二. 初始化 sweep controller

sweep_id = wandb.sweep(sweep_config, project=config.project_name)

三, 启动 Sweep agent

我们需要把模型训练相关的全部代码整理成一个 train函数。

def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)

    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,
                                            num_workers=2,drop_last=True)
    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, 
                                          num_workers=2,drop_last=True)
    return dl_train,dl_val
def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) 
    net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,
                                     out_channels=config.hidden_layer_width,kernel_size = 5))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))
    return net
def train_epoch(model,dl_train,optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features,labels = batch
        features,labels = features.to(device),labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds,labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    return model
def eval_epoch(model,dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features,labels = batch
        features,labels = features.to(device),labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds =  (predictions==labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc
def train(config = config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config); 
    optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
    #======================================================================
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)
    model.run_id = wandb.run.id
    #======================================================================
    model.best_metric = -1.0
    for epoch in range(1,config.epochs+1):
        model = train_epoch(model,dl_train,optimizer)
        val_acc = eval_epoch(model,dl_val)
        if val_acc>model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(),config.ckpt_path)   
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
        #======================================================================
        wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})
        #======================================================================        
    #======================================================================
    wandb.finish()
    #======================================================================
    return model   

#model = train(config)

一切准备妥当,点火🔥🔥。

# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

四,调参可视化和跟踪

1,平行坐标系图

可以直观展示哪些超参数组合更加容易获取更好的结果。

7366fd427d9e456030174fd9764948ec.png


2,超参数重要性图

可以显示超参数和优化目标最终取值的重要性,和相关性方向。

79447d4928b8ac70a543e57a9c7141f7.png


caa2b93c396396eb37a888a052bd7cdf.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/340245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Django框架之视图和URL

视图和URL 站点管理页面做好了, 接下来就要做公共访问的页面了.对于Django的设计框架MVT. 用户在URL中请求的是视图.视图接收请求后进行处理.并将处理的结果返回给请求者.使用视图时需要进行两步操作 1.定义视图2.配置URLconf 1. 定义视图 视图就是一个Python函数&#xff0c…

没有她的通讯录(C语言实现)

🚀write in front🚀 📝个人主页:认真写博客的夏目浅石. 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:夏目的C语言宝藏 💬总结:希望你看完之…

优劣解距离法TOPSIS——清风老师

TOPSIS法是一种常用的综合评价方法,能充分利用原始数据的信息,其结果能精确地反映各评价方案之间的差距。 基本过程为先将原始数据矩阵统一指标类型(一般正向化处理)得到正向化的矩阵,再对正向化的矩阵进行标准化处理…

​ICLR 2023 | GReTo:以同异配关系重新审视动态时空图聚合

©PaperWeekly 原创 作者 | 周正阳单位 | 中国科学技术大学论文简介动态时空图数据结构在多种不同的学科中均普遍存在,如交通流、空气质量观测、社交网络等,这些观测往往会随着时间而变化,进而引发节点间关联的动态时变特性。本文的主要…

springboot学习(八十) springboot中使用Log4j2记录分布式链路日志

在分布式环境中一般统一收集日志,但是在并发大时不好定位问题,大量的日志导致无法找出日志的链路关系。 可以为每一个请求分配一个traceId,记录日志时,记录此traceId,从网关开始,依次将traceId记录到请求头…

【C#】[带格式的字符串] 复合格式设置字符串与使用 $ 的字符串内插 | 如何格式化输出字符串

复合格式输出 string name "Fred"; String.Format("Name {0}, hours {1:hh}", name, DateTime.Now);通过指定相同的参数说明符,多个格式项可以引用对象列表中的同一个元素。 例如,通过指定“0x{0:X} {0:E} {0:N}”等复合格式字符…

凸优化学习:PART3凸优化问题(持续更新)

凸优化问题 凸优化问题的广义定义: 目标函数为凸函数约束集合为凸集 一、优化问题 基本用语 一般优化问题的描述: minimize⁡f0(x)subject to fi(x)⩽0,i1,⋯,mhi(x)0,i1,⋯,p(1)\begin{array}{ll} \operatorname{minimize} & f_0(x) \\ \text { s…

Centos7 安装Hadoop3 单机版本(伪分布式版本)

环境版本CentOS-7JDK-8Hadoop-3CentOS-7 服务器设置设置静态IP查看IP配置在/etc/sysconfig/network-scripts/目录下的ifcfg-ens33文件中。[rootHadoop3-master sbin]# cd /etc/sysconfig/network-scripts [rootHadoop3-master network-scripts]# ll 总用量 232 -rw-r--r--. 1 r…

云计算培训靠谱吗?

怎么算靠谱的培训呢? 举个例子: 我想参加云计算培训找个工作,机构满足了我的要求,有工作了,但是不是做云计算相关的。 小强也参加了云计算培训,想学好云计算成为技术大牛,最后专业学得普普通…

java环境配置

java环境配置步骤下载jdk安装jdk配置环境变量通过控制台命令验证配置是否成功大功告成安装教程: https://blog.csdn.net/m0_37220730/article/details/103585266 下载jdk 若不理解JDK/JRE/JVM的关系,可以点此查看初识Java(概念、版本迭代、…

DIN解读

传统的Embedding&MLP架构将用户特征编码进一个固定长度的向量。当推出一个商品时,该架构无法捕捉用户丰富的历史行为中的多样性兴趣与该商品的关联。阿里妈妈团队提出了DIN网络进行改进,主要有如下两点创新: 引入注意力机制来捕捉历史行为…

【Linux下代码调试工具】gdb 的基本使用

gdb的基本使用前言准备gdb工具调试须知gdb的基本指令进入调试退出调试显示代码及函数内容运行程序给程序打断点查看断点位置断点使能取消断点逐过程调试逐语句调试运行到下一个断点查看变量的值变量值常显示取消变量值常显示前言 在主页前面的几篇文章已经介绍了Vim编辑器及Ma…

C语言(内联函数(C99)和_Noreturn)

1.内联函数 通常,函数调用都有一定的开销,因为函数的调用过程包含建立调用,传递参数,跳转到函数代码并返回。而使用宏是代码内联,可以避开这样的开销。 内联函数:使用内联diamagnetic代替函数调用。把函数…

【MySQL】 事务

😊😊作者简介😊😊 : 大家好,我是南瓜籽,一个在校大二学生,我将会持续分享Java相关知识。 🎉🎉个人主页🎉🎉 : 南瓜籽的主页…

opencv窗口的创建/显示/销毁

大家好,我是csdn的博主:lqj_本人 这是我的个人博客主页: lqj_本人的博客_CSDN博客-微信小程序,前端,python领域博主lqj_本人擅长微信小程序,前端,python,等方面的知识https://blog.csdn.net/lbcyllqj?spm1011.2415.3001.5343哔哩哔哩欢迎关注…

Docker-harbor私有仓库

一、Harbor概述 1、Harbor的概念 • Harbor是VMware公司开源的企业级Docker Registry项目,其目标是帮助用户迅速搭建一个企业级的Docker Registry服务 • Harbor以 Docker 公司开源的Registry 为基础,提供了图形管理UI、基于角色的访问控制(Role Base…

力扣SQL刷题9

目录1107. 每日新用户统计-勉强579. 查询员工的累计薪水 - 各种绕易错点:range与rows区别615. 平均工资:部门与公司比较with建临时表注意点1107. 每日新用户统计-勉强 题型:每个日期中首次登录人数 解答:从原表中用按用户分组后m…

嵌入式开发----示波器入门

示波器入门前言一、示波器介绍关键指标工作原理二、功能按钮介绍三、一键入门四、 典型应用场景校准捕捉测试总线通讯总结前言 对于嵌入式工程师来说,示波器的使用极为重要,他就像是“电子工程师的眼睛”,把被测信号的实际波形显示在屏幕上&…

Java特性之设计模式【桥接模式】

一、桥接模式 概述 桥接(Bridge)是用于把抽象化与实现化解耦,使得二者可以独立变化。这种类型的设计模式属于结构型模式,它通过提供抽象化和实现化之间的桥接结构,来实现二者的解耦 这种模式涉及到一个作为桥接的接口…

67. 二进制求和

文章目录题目描述竖式模拟转换为十进制计算题目描述 给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 示例 1: 输入:a “11”, b “1” 输出:“100” 示例 2: 输入:a “1010”, b “1011” …