训练日志刷屏使我痛苦,我开发了VLog

news2025/3/10 15:07:16

训练日志刷屏使我痛苦,我开发了VLog,可以在任意训练代码中轻松使用~

例如,通过callback嵌入到lightgbm/catboost/transformers/ultralytics,乃至keras库的训练代码流程中~

before:

f863c0c0a2c38333f1d4585dec341cdc.png

after

b76aa5c885f754c3a94129980f6dae0d.gif

为什么不用tensorboard或者wandb?

tensorboard需要开端口权限,服务器开发环境有时候没有端口权限~

wandb需要联网,有时候网速很差或者没有网,影响体验~

综合对比考虑如下表

f99a200b86841b54b49068d426bb1dce.png

一,VLog基本原理

VLog类主要有以下5个方法。

from torchkeras import VLog

#1, 初始化方法
vlog = VLog(epochs=20, monitor_metric='val_loss', monitor_mode='min') 

#2, 显示开始空图表
vlog.log_start()

#3, 更新step级别日志
vlog.log_step({'train_loss':0.003,'val_loss':0.002}) 

#4, 更新epoch级别日志
vlog.log_epoch({'train_acc':0.9,'val_acc':0.87,'train_loss':0.002,'val_loss':0.03})

#5, 输出最终稳定状态图表
vlog.log_end()
import time
import math,random
from torchkeras import VLog

epochs = 10
batchs = 30

#0, 指定监控北极星指标,以及指标优化方向
vlog = VLog(epochs, monitor_metric='val_loss', monitor_mode='min') 

#1, log_start 初始化动态图表
vlog.log_start() 

for epoch in range(epochs):
    
    #train
    for step in range(batchs):
        
        #2, log_step 更新step级别日志信息,打日志,并用小进度条显示进度
        vlog.log_step({'train_loss':100-2.5*epoch+math.sin(2*step/batchs)}) 
        time.sleep(0.05)
        
    #eval    
    for step in range(20):
        
        #3, log_step 更新step级别日志信息,指定training=False说明在验证模式,只打日志不更新小进度条
        vlog.log_step({'val_loss':100-2*epoch+math.sin(2*step/batchs)},training=False)
        time.sleep(0.05)
        
    #4, log_epoch 更新epoch级别日志信息,每个epoch刷新一次动态图表和大进度条进度
    vlog.log_epoch({'val_loss':100 - 2*epoch+2*random.random()-1,
                    'train_loss':100-2.5*epoch+2*random.random()-1})  

# 5, log_end 调整坐标轴范围,输出最终指标可视化图表
vlog.log_end()

830d0af5b03eb6721cb623407037dcd6.png

二,在LightGBM中使用VLog

设计一个简单的回调,就可以搞定~

from torchkeras import VLog
class VLogCallback:
    def __init__(self, num_boost_round, 
                 monitor_metric='val_loss',
                 monitor_mode='min'):
        self.order = 20
        self.num_boost_round = num_boost_round
        self.vlog = VLog(epochs = num_boost_round, monitor_metric = monitor_metric, 
                         monitor_mode = monitor_mode)

    def __call__(self, env) -> None:
        metrics = {}
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                metrics[data_name+'_'+eval_name] = result
            else:
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]
                metrics[data_name+'_'+eval_name] = res_mean
        self.vlog.log_epoch(metrics)
import datetime
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(info+'...\n\n')

#================================================================================
# 一,读取数据
#================================================================================
printlog("step1: reading data...")

# 读取dftrain,dftest
breast = datasets.load_breast_cancer()
df = pd.DataFrame(breast.data,columns = [x.replace(' ','_') for x in breast.feature_names])
df['label'] = breast.target
df['mean_radius'] = df['mean_radius'].apply(lambda x:int(x))
df['mean_texture'] = df['mean_texture'].apply(lambda x:int(x))
dftrain,dftest = train_test_split(df)

categorical_features = ['mean_radius','mean_texture']
lgb_train = lgb.Dataset(dftrain.drop(['label'],axis = 1),label=dftrain['label'],
                        categorical_feature = categorical_features)

lgb_valid = lgb.Dataset(dftest.drop(['label'],axis = 1),label=dftest['label'],
                        categorical_feature = categorical_features,
                        reference=lgb_train)

#================================================================================
# 二,设置参数
#================================================================================
printlog("step2: setting parameters...")
                               
boost_round = 50                   
early_stop_rounds = 10

params = {
    'boosting_type': 'gbdt',
    'objective':'binary',
    'metric': ['auc'], #'l2'
    'num_leaves': 15,   
    'learning_rate': 0.05,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'verbose': 0,
    'early_stopping_round':5
}

#================================================================================
# 三,训练模型
#================================================================================
printlog("step3: training model...")

result = {}

vlog_cb = VLogCallback(boost_round, monitor_metric = 'val_auc', monitor_mode = 'max')
vlog_cb.vlog.log_start()

gbm = lgb.train(params,
                lgb_train,
                num_boost_round= boost_round,
                valid_sets=(lgb_valid, lgb_train),
                valid_names=('val','train'),
                callbacks = [lgb.record_evaluation(result),
                             vlog_cb]
               )

vlog_cb.vlog.log_end()

#================================================================================
# 四,评估模型
#================================================================================
printlog("step4: evaluating model ...")

y_pred_train = gbm.predict(dftrain.drop('label',axis = 1), num_iteration=gbm.best_iteration)
y_pred_test = gbm.predict(dftest.drop('label',axis = 1), num_iteration=gbm.best_iteration)

print('train accuracy: {:.5} '.format(accuracy_score(dftrain['label'],y_pred_train>0.5)))
print('valid accuracy: {:.5} \n'.format(accuracy_score(dftest['label'],y_pred_test>0.5)))


#================================================================================
# 五,保存模型
#================================================================================
printlog("step5: saving model ...")


model_dir = "gbm.model"
print("model_dir: %s"%model_dir)
gbm.save_model("gbm.model")
printlog("task end...")

###
##
#
================================================================================2023-11-10 15:39:38
step1: reading data......



================================================================================2023-11-10 15:39:38
step2: setting parameters......



================================================================================2023-11-10 15:39:38
step3: training model......

6c24936c42819347399d1cfea07826d9.png

================================================================================2023-11-10 15:39:44
step4: evaluating model ......


train accuracy: 0.95775 
valid accuracy: 0.94406 


================================================================================2023-11-10 15:39:44
step5: saving model ......


model_dir: gbm.model

================================================================================2023-11-10 15:39:44
task end......

三, 在ultralytics中使用VLog

写个适配的回调~

ultralytics可以做 分类,检测,分割 等等。

这个回调函数是通用的,此处以分类问题为例,改个monitor_metric即可~

cats_vs_dogs数据集可以在公众号算法美食屋后台回复:torchkeras 获取~

from torchkeras import VLog
class VLogCallback:
    def __init__(self,epochs,monitor_metric,monitor_mode):
        self.vlog = VLog(epochs,monitor_metric,monitor_mode)
        
    def on_train_batch_end(self,trainer):
        self.vlog.log_step(trainer.label_loss_items(trainer.tloss, prefix='train'))

    def on_fit_epoch_end(self,trainer):
        metrics = {k.split('/')[-1]:v for k,v in trainer.metrics.items() if 'loss' not in k}
        self.vlog.log_epoch(metrics)

    def on_train_epoch_end(self,trainer):
        pass
from ultralytics import YOLO 
epochs = 10

vlog_cb = VLogCallback(epochs = epochs,
                       monitor_metric='accuracy_top1',
                       monitor_mode='max')
callbacks = {
    "on_train_batch_end": vlog_cb.on_train_batch_end,
    "on_fit_epoch_end": vlog_cb.on_fit_epoch_end
}

model = YOLO(model = 'yolov8n-cls.pt')
for event,func in callbacks.items():
    model.add_callback(event,func)
    
vlog_cb.vlog.log_start()
results = model.train(data='cats_vs_dogs', 
                      epochs=epochs, workers=4)     # train the model
vlog_cb.vlog.log_end()

8aa62f4f707a58de1003cf4e4eb9daef.png

四, 在transformers中使用VLog

waimai评论数据集可以在公众号算法美食屋后台回复:torchkeras 获取~

#回调给你写好了~
from torchkeras.tools.transformers import VLogCallback
import numpy as np 
import pandas as pd 
import torch 
import datasets 
from transformers import AutoTokenizer,DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification 
from transformers import TrainingArguments,Trainer 
from transformers import EarlyStoppingCallback

from tqdm import tqdm 
from transformers import AdamW, get_scheduler


#一,准备数据

df = pd.read_csv("waimai_10k.csv")
ds = datasets.Dataset.from_pandas(df)
ds = ds.shuffle(42) 
ds = ds.rename_columns({"review":"text","label":"labels"})

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese') 

ds_encoded = ds.map(lambda example:tokenizer(example["text"]),
                    remove_columns = ["text"],
                    batched=True)

#train,val,test split
ds_train_val,ds_test = ds_encoded.train_test_split(test_size=0.2).values()
ds_train,ds_val = ds_train_val.train_test_split(test_size=0.2).values() 

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=16, collate_fn = data_collator)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=16,  collate_fn = data_collator)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=16,  collate_fn = data_collator)

for batch in dl_train:
    break
print({k: v.shape for k, v in batch.items()})



#二,定义模型
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-chinese',num_labels=2)

#三,训练模型
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    accuracy = np.sum(preds==labels)/len(labels)
    precision = np.sum((preds==1)&(labels==1))/np.sum(preds==1)
    recall = np.sum((preds==1)&(labels==1))/np.sum(labels==1)
    f1  = 2*recall*precision/(recall+precision)
    return {"accuracy":accuracy,"precision":precision,"recall":recall,'f1':f1}

training_args = TrainingArguments(
    output_dir = "bert_waimai",
    num_train_epochs = 3,
    logging_steps = 20,
    gradient_accumulation_steps = 10,
    evaluation_strategy="steps", #epoch
    
    metric_for_best_model='eval_f1',
    greater_is_better=True,
    
    report_to='none',
    load_best_model_at_end=True
)

callbacks = [EarlyStoppingCallback(early_stopping_patience=10),
             VLogCallback()] #监控指标同 metric_for_best_model

trainer = Trainer(
    model,
    training_args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    compute_metrics=compute_metrics,
    callbacks = callbacks,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train() 



#四,评估模型
trainer.evaluate(ds_val)


#五,使用模型
from transformers import pipeline
model.config.id2label = {0:"差评",1:"好评"}
classifier = pipeline(task="text-classification",tokenizer = tokenizer,model=model.cpu())
classifier("挺好吃的哦")

#六,保存模型
model.save_pretrained("waimai_10k_bert")
tokenizer.save_pretrained("waimai_10k_bert")

classifier = pipeline("text-classification",model="waimai_10k_bert")
classifier(["味道还不错,下次再来","我去,吃了我吐了三天"])

b9371ef095ed42846e9cddf7bc6e9df1.png

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和数据集下载链接~

5b939b3ca198ae4f5a83d2b6f9dcc79b.png

万水千山总是情,点个赞赞行不行~😋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1198477.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux的make和Makefile

目录 一、 介绍二、快速使用三、依赖关系和依赖方法四、语法 一、 介绍 1、makefile带来的好处就是——“自动化编译”,一旦写好,只需要一个make命令,整个工程完全自动编译,极大的提高了软件开发的效率。 2、make是一个命令工具&…

dcat admin 各种问题

样式问题 如何根据条件给表格数据栏添加背景色 use Illuminate\Support\Collection;protected function grid(){return Grid::make(new BookArticle(), function (Grid $grid) {... 其他代码// Collection的完整路径:Illuminate\Support\Collection;$grid->row…

火星加载WMTS服务

这是正常的加载瓦片 http://192.168.1.23:8008/geoserver/mars3d/gwc/service/wmts?tilematrixEPSG%3A4326%3A7&layermars3d%3Abuffer&style&tilerow46&tilecol197&tilematrixsetEPSG%3A4326&formatimage%2Fpng&serviceWMTS&version1.0.0&…

超详细介绍对极几何和立体视觉及 Python 和 C++实现

您是否想过为什么戴着特殊的 3D 眼镜观看电影时可以体验到美妙的 3D 效果?或者为什么闭上一只眼睛很难接住板球?这一切都与立体视觉有关,立体视觉是我们用双眼感知深度的能力。这篇文章使用 OpenCV 和立体视觉为计算机提供这种感知深度的能力。代码以 Python 和 C++ 形式提供…

迷雾系统-1 地图及其区块

创建UGUI地图,每块地块(Image)上添加AreaNode脚本,根据PolygonCollider2D可视化编辑碰撞体形状,并以此生成Mesh Mc_AreaNode脚本: private GameObject _objPrefab; //创建的Mesh预制体private float _canvasPosZ;pr…

[N-133]基于springboot,vue小说网站

开发工具:IDEA 服务器:Tomcat9.0, jdk1.8 项目构建:maven 数据库:mysql5.7 系统分前后台,项目采用前后端分离 前端技术:vueelementUI 服务端技术:springbootmybatis-plus 本项…

【Java面向对象编程(中)】- 探索封装的秘密

🌈个人主页: Aileen_0v0🔥系列专栏:Java学习系列专栏💫个人格言:"没有罗马,那就自己创造罗马~" 目录 回顾 封装​编辑 为什么进行封装 ​​编辑​ 如何调用私有的变量 ​​编辑​ 1.get set方法(当形参和成员变量不同名时)​…

LeetCode | 138. 随机链表的复制

LeetCode | 138. 随机链表的复制 OJ链接 思路: 题目要求我们拷贝一个带next指针与random随机访问指针的链表。 如果只拷贝一个只带next的指针,直接遍历目标链表依次拷贝每个节点的信息就可以了~~ 拷贝节点插入到原节点的后面处理copy节点的randomcop…

Leetcode—103.二叉树的锯齿形层序遍历【中等】

2023每日刷题(二十六) Leetcode—103.二叉树的锯齿形层序遍历 BFS实现代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ /*** Return an array of ar…

138.随机链表的复制(LeetCode)

深拷贝,是指将该链表除了正常单链表的数值和next指针拷贝,再将random指针进行拷贝 想法一 先拷贝出一份链表,再对于每个节点的random指针,在原链表进行遍历,找到random指针的指向,最后完成拷贝链表random…

第一百六十八回 NavigationBar组件

文章目录 1. 概念介绍2. 使用方法3. 代码与效果3.1 示例代码3.2 运行效果 4. 内容总结 我们在上一章回中介绍了"如何修改按钮的形状"相关的内容,本章回中将 介绍NavigationBar组件.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在本…

Linux驱动开发——PCI设备驱动

目录 一、 PCI协议简介 二、PCI和PCI-e 三、Linux PCI驱动 四、 PCI设备驱动实例 五、 总线类设备驱动开发习题 一、 PCI协议简介 PCI (Peripheral Component Interconnect,外设部件互联) 局部总线是由Intel 公司联合其他几家公司一起开发的一种总线标准&#…

初识-Servlet (第一个 Servlet 程序详解)

Servlet 是什么? Servlet 是一种实现动态页面的技术. 是一组 Tomcat 提供给程序员的 API, 帮助程序员简单高效的开发一个 web app. 静态页面就只是单纯的 html 动态页面则是 html 数据 第一个 Servlet 程序 我们写一个 hello world 预期写一个 Servlet 程序, 部署到 Tomca…

图论12-无向带权图及实现

文章目录 带权图1.1带权图的实现1.2 完整代码 带权图 1.1带权图的实现 在无向无权图的基础上,增加边的权。 使用TreeMap存储边的权重。 遍历输入文件,创建TreeMap adj存储每个节点。每个输入的adj节点链接新的TreeMap,存储相邻的边和权重 …

时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

官方论文地址->官方论文地址 官方代码地址->官方代码地址 个人修改代码->个人修改的代码已经上传CSDN免费下载 一、本文介绍 本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核…

Leetcode刷题详解—— 目标和

1. 题目链接:494. 目标和 2. 题目描述: 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联起所有整数,可以构造一个 表达式 : 例如,nums [2, 1] ,可…

【计算机网络笔记】IP分片

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

有没有实时检测微信聊天图片的软件,只要微信收到了有二维码的图片就把它提取出来?

10-2 如果你有需要自动并且快速地把微信收到的二维码图片保存到指定文件夹的需求,那本文章非常适合你,本文章教你如何实现自动保存微信收到的二维码图片到你指定的文件夹中,助你快速扫码,比别人领先一步。 首先需要准备好的材料…

19 异步通知

一、异步通知 1. 异步通知简介 阻塞和非阻塞两种方式都是需要应用程序去主动查询设备的使用情况。 异步通知类似于驱动可以主动报告自己可以访问,应用程序获取信号后会从驱动设备中读取或写入数据。 异步通知最核心的就是信号: #define SIGHUP 1 /* 终…

openssl研发之base64编解码实例

一、base64编码介绍 Base64编码是一种将二进制数据转换成ASCII字符的编码方式。它主要用于在文本协议中传输二进制数据,例如电子邮件的附件、XML文档、JSON数据等。 Base64编码的特点如下: 字符集: Base64编码使用64个字符来表示二进制数据…