WandB 简明教程【Weights Bias】

news2024/11/18 11:55:39

在机器学习实验领域,调整超参数类似于微调复杂机器的旋钮和刻度盘。这些参数通常很微妙但至关重要,能够显著影响我们模型的性能和行为。WandB(权重和偏差 ) 是一个强大的在线工具集,旨在简化模型训练、评估和分析的过程。

随着我们开始探索,我们将揭示超参数的本质,这些关键变量决定了我们的机器学习模型的行为和功效。通过 wandb 的视角,我们将深入研究跟踪这些参数、可视化模型性能以及将实验无缝集成到我们的工作流程中的艺术。加入我们,揭开超参数管理的复杂性,为充分利用 wandb 在机器学习方面的潜力奠定基础。

 NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

1、环境准备

首先我们需要在wandb上建立一个简单的设置,以帮助我们监控模型性能。为此,我们需要首先在wandb上创建用户登录凭据。

创建一组登录凭据后,你需要获取 API 登录密钥,该密钥将位于你的“用户设置”中。接下来,需要使用以下命令安装 wandb API:

pip install wandb

接下来你将使用以下命令:

wandb login

并提供 API 密钥和登录用户名。

2、测试问题

现在我们将制作一个简单的深度学习模型,我们的 wandb 应用程序将集成到该模型中。对于本文,我决定在 TensorFlow 中训练一个模型,该模型将检测某人是否在微笑。这是一个简单的两类分类问题。

我们首先从导入必要的库开始:

import tensorflow as tf
from keras import layers
import matplotlib.pyplot as plt
import keras
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc
import seaborn as sns
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers import BatchNormalization
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
import wandb
import h5py, os
import numpy as np
from wandb.keras import WandbCallback

接下来,我们将初始化 wandb API 并在其中创建我们的项目。此项目标题将帮助我们将其他项目区分开来。对于本文,项目名称将是“smile-no-smile”:

wandb.init(project="smile-no-smile")

完成后,我们将进行初始数据加载和预处理阶段:

path_to_train = "Data/happy/train_happy.h5/train_happy.h5"
path_to_test = "Data/happy/test_happy.h5/test_happy.h5"

### check and load the data
if not os.path.exists(path_to_train):
    raise ValueError("Can't find the train dataset")
else:
    train_dataset = h5py.File(path_to_train, "r")
    test_dataset = h5py.File(path_to_test, "r")

# Normalize the images
train_x = train_dataset["train_set_x"][:] / 255.
test_x = test_dataset["test_set_x"][:] / 255.

# Get the labels
train_y = train_dataset["train_set_y"][:]
test_y = test_dataset["test_set_y"][:]

# We will also split the train set into train and validation set
from sklearn.model_selection import train_test_split
train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size=0.2, random_state=42)

为了可视化我们的数据,我们将绘制图像及其各自的标签,如下所示:

### we use matplotlib to plot sample images for each class

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

plt.figure(figsize=(10,10))
for i in range(5):
    plt.subplot(1,5,i+1)
    plt.imshow(train_dataset["train_set_x"][i])
    plt.axis("off")
    plt.title("Label: " + str(train_dataset["train_set_y"][i]))

plt.show()

这将得到如下输出:

这里标签 0 表示没有微笑,1 表示微笑

为了训练模型,我们还将对图像进行一些预处理和增强,以便更好地训练模型。我们可以使用 keras 制作一个简单的数据增强管道,它将是位于模型的输入和内部层之间的一层。实际上,它看起来像这样:

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(0.2),
        layers.RandomContrast(0.3),
        layers.RandomBrightness(0.3),
    ]
)

### Model with data augmentaion pipeline ###
inputs = keras.Input(shape=(64, 64, 3))
x = data_augmentation(inputs)
x = layers.Conv2D(filters=32, kernel_size=5, activation="relu")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

如我们所见,代码中已经包含了数据增强部分。

现在对于训练部分,我们还可以实施一些不同的技术来帮助模型更快地达到局部最小值。其中一种方法是使用衰减学习率。这也有一些额外的好处,例如:

  • 更快地达到最小值
  • 在达到更高的准确度之前降低模型停滞的风险。
  • 获得更好的整体学习

为此,我们将定义一个函数,该函数将使用 keras.callbacks 根据时期动态更新学习率,即它将从较高的学习率开始,并在接近最终时期时不断降低学习率。

def step_decay(epoch):
    initial_lr = 0.001  # Initial learning rate
    drop = 0.5  # Learning rate drop factor
    epochs_drop = 5  # Number of epochs after which learning rate will drop
    new_lr = initial_lr * (drop ** (epoch // epochs_drop))
    return new_lr

### call the method step_decay() using keras learningRateScheduler method
lr_scheduler_exponential_decay = LearningRateScheduler(step_decay)

最后,我们将编译模型并有一个检查点回调,它将从模型停止训练的时间点加载并继续训练过程:

modelName = "happ-sadM.h5"
try:
    model = load_model(modelName)
    print("Loaded model from disk")
except:
    print("No model found, creating new one")


checkpoint = ModelCheckpoint(
	'happ-sadM.h5', monitor='val_loss', save_best_only=True)
### we will now compile the model
model.compile(optimizer='adam',
              loss="binary_crossentropy",
              metrics=['accuracy'])

3、WandB API 和指标

现在我们的模型已准备好进行训练和测试,我们将设置所有回调并将其连接到 WandB API,并查看 API 能够捕获的各种见解。

history = model.fit(train_x, train_y, epochs=40, 
          validation_data=(val_x, val_y),
          callbacks=[checkpoint, lr_scheduler_exponential_decay, WandbCallback()])

WandbCallback() 方法将为我们处理大部分指标捕获。但是,如果我们想添加一些其他信息,例如混淆矩阵或准确度分数,我们可以使用 wandb.log() 来实现相同的目的。

例如,我们可以将以下代码片段添加到上面的代码中:

preds = model.predict(test_x)
preds = np.round(preds).astype(int).reshape(1, -1)[0]

fpr, tpr, thresholds = roc_curve(test_y, preds)
roc_auc = auc(fpr, tpr)

wandb.log({'accuracy': accuracy_score(test_y, preds), "roc_curve": wandb.Image(plt)})

通过使用 wandb.log(),我们可以将某些输出(如图表)添加为图像,我们可以在 wandb 仪表板上看到这些图像:

WandB 仪表板的指标

如我们所见,我们的仪表板将提供对模型进行完整评估所需的所有指标,同时还将提供 GPU 利用率,这在云资源上训练大型模型时非常有用。

4、结束语

从这篇文章中,我们能够评估一个简单的模型并提供必要的详细信息。在下一篇文章中,我们将使用 wandb 工具集中的 Sweep 方法来找到我们模型的最佳参数集。


原文链接:WandB 简明教程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2052385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TCP shutdown 之后~

目录 摘要 1 API 2 shutdown(sockfd, SHUT_WR) 3 shutdown(sockfd, SHUT_WR) 4 kernel 是怎么做的? 附 摘要 通过 shutdown() 关闭读写操作,会发生什么?具体点呢,考虑两个场景: 场景一:C 发送数据完毕…

VBA技术资料MF184:图片导入Word添加说明文字设置格式

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套,分为初级、中级、高级三大部分,教程是对VBA的系统讲解&#…

C ++初阶:C++入门级知识点

🍺0.前言 言C之言,聊C之识,以C会友,共向远方。各位博友的各位你们好啊,这里是持续分享C知识的小赵同学,今天要分享的C知识是C入门知识点,在这一章,小赵将会向大家展开聊聊C入门知识…

基于Mediapipe的手势识别系统 | OpenCV | Mediapipe | C++ | QT | Python | C# | Unity

基于Mediapipe的手势识别系统 OpenCV、Mediapipe C (QT)、Python (PyCharm)、C# (Visual Studio) Unity 3D 登录界面 图片手势识别 视频文件手势识别 摄像头实时手势识别 演示视频 基于Mediapipe的手势识别系统

UDP和TCP协议段格式分析

目录 UDP协议 特点 UDP协议的缓冲区 UDP协议段格式 TCP协议 特点 如何理解TCP是传输控制协议? TCP协议段格式 四位首部长度 16位窗口大小 32位序号 32位确认序号 TCP/IP四层模型: UDP协议 UDP(User Datagram Protocol &#xff…

十大护眼落地灯品牌哪款好?十大护眼落地灯品牌

十大护眼落地灯品牌哪款好?根据国际市场的研究数据表明,我国在日常生活中对电子产品的依赖度极高,每天看电子产品的时间超过8小时,出现眼睛酸痛、干涩、视觉疲劳的人群也不再少数,而给眼睛带来伤害的除了电子产品中所含…

界面控件DevExpress ASP.NET Web Forms v24.1最新版本系统环境配置要求

本文档包含有关安装和使用 DevExpress ASP.NET Web Forms控件的系统要求的信息。 点击获取DevExpress v24.1正式版 .NET Framework DevExpress ASP.NET Web Forms控件支持以下.NET框架版本。 如果您需要 DevExpress 产品的早期版本,请直接戳这里联系我>> …

MySQL中的EXPLAIN的详解

一、介绍 官网介绍: https://dev.mysql.com/doc/refman/5.7/en/explain-output.htmlhttps://dev.mysql.com/doc/refman/8.0/en/explain-output.htmlexplain(执行计划),使用explain关键字可以模拟优化器执行sql查询语句&#xff…

爆火的本地知识库项目是什么?什么是RAG?本地知识库与大模型的关系

“ 本地知识库就相当于大模型的外部资料库。” 很多人应该都听过本地知识库项目,它是当今人工智能领域爆火的项目之一,那么到底什么是本地知识库?它和大模型有什么关系?怎么构建本地知识库? 01 — 为什么需要本地知…

Docker的介绍、保姆级安装和使用

一、Docker简介 1.1、Docker是什么 Docker是一个用于开发、发布和运行应用程序的开放平台;使您能够将应用程序与基础设施分离,以便您可以快速交付软件。不像虚拟机那样笨重(比如:我需要将一个安装好nginx环境的内容分享给其他人: 方式一【使用虚拟】(应用程序Nginx与基…

系统架构设计师 - 软件工程(2)

软件工程 软件工程(13-22分)非常重要软件系统建模系统设计界面设计 ★★软件设计结构化设计 ★★面向对象设计 ★★★★★基本过程设计原则设计模式创建型模式:创建对象结构型模式:更大的结构行为型模式:交互及职责分配…

四川财谷通信息技术有限公司抖音小店优势解析

在数字经济蓬勃发展的今天,电商平台如雨后春笋般涌现,其中,四川财谷通信息技术有限公司旗下的抖音小店凭借其独特的优势和强大的实力,在众多竞争者中脱颖而出,成为消费者和商家信赖的优选平台。本文将详细解析四川财谷…

Windows键快捷键大全

Windows键快捷键大全 Windows键结合其他键可以执行多种快捷操作,以下是一些常用的Windows键快捷键: Windows键 D: 显示或隐藏桌面。Windows键 E: 打开文件资源管理器。Windows键 L: 锁定电脑。Windows键 R: 打开运行对话框。Windows键 I: 打开Win…

Java中JDK动态代理

参考:疯狂Java讲义 第18章 文章目录 前言复杂度与耦合的矛盾 使用JDK动态代理总结 前言 复杂度与耦合的矛盾 开发实际应用的软件系统时,通常会存在相同代码段重复出现的情况,在这种情况下,一般都提取为一个方法,在不…

SOP企业内部推行:效率飙升100%,质量保障零瑕疵!

在企业的日常运营中,你是否经常遇到这样的问题:同样一项工作,不同的人做出来效果却大相径庭?或者,明明已经制定了工作流程,但执行起来却总是出现偏差,导致效率低下、质量不稳?这些问…

【STM32单片机_(HAL库)】3-2-2【中断EXTI】【电动车报警器项目】继电器定时开闭

1.硬件 STM32单片机最小系统继电器模块 2.软件 继电器模块alarm驱动文件添加GPIO常用函数main.c程序 #include "sys.h" #include "delay.h" #include "led.h" #include "alarm.h"int main(void) {HAL_Init(); …

海外服务器和内地服务器有什么区别?

海外服务器和内地服务器在许多方面存在区别,主要包括以下几个方面: 1. 地理位置 海外服务器:位于中国大陆以外的地区,比如美国、欧洲、东南亚等地。常见的海外服务器提供商有Amazon Web Services(AWS)、Goo…

稚晖君发布5款全能人形机器人,开源创新,全能应用

8月18日,智元机器人举行“智元远征 商用启航” 2024年度新品发布会,智元联合创始人彭志辉主持并发布了“远征”与“灵犀”两大系列共五款商用人形机器人新品——远征A2、远征A2-W、远征A2-Max、灵犀X1及灵犀X1-W,并展示了在机器人动力、感知、…

【LLM之Base Model】Weaver论文阅读笔记

研究背景 当前的大型语言模型(LLM)如GPT-4等,尽管在普通文本生成中表现出色,但在创造性写作如小说、社交媒体内容等方面,往往不能很好地模仿人类的写作风格。这些模型在训练和对齐阶段,往往使用的是大规模…

Java | Leetcode Java题解之第347题前K个高频元素

题目&#xff1a; 题解&#xff1a; class Solution {public int[] topKFrequent(int[] nums, int k) {Map<Integer, Integer> occurrences new HashMap<Integer, Integer>();for (int num : nums) {occurrences.put(num, occurrences.getOrDefault(num, 0) 1);…