optuna和 lightgbm

news2024/12/31 5:38:46

文章目录

  • optuna使用
    • 1.导入相关包
    • 2.定义模型可选参数
    • 3.定义训练代码和评估代码
    • 4.定义目标函数
    • 5.运行程序
    • 6.可视化
    • 7.超参数的重要性
    • 8.查看相关信息
    • 9.可视化的一个完整示例
    • 10.lightgbm实验

optuna使用

1.导入相关包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fvcore.nn import FlopCountAnalysis

import optuna


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30   # 128 * 30个训练
N_VALID_EXAMPLES = BATCHSIZE * 10   # 128 * 10个预测

2.定义模型可选参数

optuna支持很多种搜索方式:
(1)trial.suggest_categorical(‘optimizer’, [‘MomentumSGD’, ‘Adam’]):表示从SGD和adam里选一个使用;
(2)trial.suggest_int(‘num_layers’, 1, 3):从1~3范围内的int里选;
(3)trial.suggest_uniform(‘dropout_rate’, 0.0, 1.0):从0~1内的uniform分布里选;
(4)trial.suggest_loguniform(‘learning_rate’, 1e-5, 1e-2):从1e-5~1e-2的log uniform分布里选;
(5)trial.suggest_discrete_uniform(‘drop_path_rate’, 0.0, 1.0, 0.1):从0~1且step为0.1的离散uniform分布里选;

def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3) # 从[1,3]范围里面选一个
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)

3.定义训练代码和评估代码

# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    flops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),)).total()
    return flops, accuracy

4.定义目标函数

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy

5.运行程序

运行30次实验,每次实验返回 flops,accuracy

study = optuna.create_study(directions=["minimize", "maximize"]) # flops 最小化, accuracy 最大化
study.optimize(objective, n_trials=30, timeout=300)

print("Number of finished trials: ", len(study.trials))

6.可视化

flops, accuracy 二维图
optuna.visualization.plot_pareto_front(study, target_names=[“FLOPS”, “accuracy”])

在这里插入图片描述

7.超参数的重要性

对于flops
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[0], target_name=“flops”
)

对于accuracy
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[1], target_name=“accuracy”
)

在这里插入图片描述

8.查看相关信息

# https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/002_multi_objective.html
# 利用pytorch mnist 识别
# 设置了一些超参数,lr, layer number, feature_number等
# 然后目标是 flops 和 accurary

# 最后是可视化:
# 显示试验的一些结果:
# optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
# 左上角是最好的

# 显示重要性:
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[0], target_name="flops"
# )
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[1], target_name="accuracy"
# )


# trials的属性:
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")

trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print(f"Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")

9.可视化的一个完整示例

# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timeline

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)

        val_accuracy = eval_model(model, val_loader)
        trial.report(val_accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_accuracy

study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=30, timeout=300)

运行之后可视化:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

10.lightgbm实验

"""
Optuna example that optimizes a classifier configuration for cancer dataset using LightGBM.

In this example, we optimize the validation accuracy of cancer detection using LightGBM.
We optimize both the choice of booster model and their hyperparameters.

"""

import numpy as np
import optuna

import lightgbm as lgb
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split


# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)

    param = {
        "objective": "binary",
        "metric": "binary_logloss",
        "verbosity": -1,
        "boosting_type": "gbdt",
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 2, 256),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
        "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
    }

    gbm = lgb.train(param, dtrain)
    preds = gbm.predict(valid_x)
    pred_labels = np.rint(preds)
    accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
    return accuracy


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

运行结果:
在这里插入图片描述

https://github.com/microsoft/LightGBM/tree/master/examples

https://blog.csdn.net/yang1015661763/article/details/131364826

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SD ComfyUI工作流 对人物图像进行抠图并替换背景

文章目录 人物抠图与换背景SD模型Node节点工作流程工作流下载效果展示人物抠图与换背景 此工作流旨在通过深度学习模型完成精确的人物抠图及背景替换操作。整个流程包括图像加载、遮罩生成、抠图处理、背景替换以及最终的图像优化。其核心基于 SAM(Segment Anything Model)与…

【C语言程序设计——循环程序设计】利用循环求数值 x 的平方根(头歌实践教学平台习题)【合集】

目录😋 任务描述 相关知识 一、求平方根的迭代公式 1. 原理 2. 代码实现示例 二、绝对值函数fabs() 1. 函数介绍 2. 代码示例 三、循环语句 1. for循环 2. while循环 3. do - while循环 编程要求 测试说明 通关代码 测试结果 任务描述 本关任务&…

程序猿成长之路之设计模式篇——结构型设计模式

本篇开始介绍结构型设计模式 前言 与创建型设计模式用于创建对象不同,结构型设计模式通过结构化的方式实现功能的扩展和解耦,通过对象的组合、聚合、继承和接口等机制来定义对象之间的关系,从而实现松耦合和灵活性。 常见的结构性设计模式&…

低代码开源项目Joget的研究——Joget8社区版安装部署

大纲 环境准备安装必要软件配置Java配置JAVA_HOME配置Java软链安装三方库 获取源码配置MySql数据库创建用户创建数据库导入初始数据 配置数据库连接配置sessionFactory(非必须,如果后续保存再配置)编译下载tomcat启动下载aspectjweaver移动jw…

数据库的概念和操作

目录 1、数据库的概念和操作 1.1 物理数据库 1. SQL SERVER 2014的三种文件类型 2. 数据库文件组 1.2 逻辑数据库 2、数据库的操作 2.1 T-SQL的语法格式 2.2 创建数据库 2.3 修改数据库 2.4 删除数据库 3、数据库的附加和分离 1、数据库的概念和操作 1.1 物理数据库…

React中最优雅的异步请求

给大家分享在React19中使用useSuspense处理异步请求为什么是被认为最优雅的解决方案 一. 传统方案 解决异步请求的方案中,我们要处理至少两个最基本的逻辑 正常的数据显示数据加载的UI状态 例如: export default function Index(){const [content, …

基于Bregman的交替方向乘子法

目录标题 ADMM方法简介Bregman散度Bregman ADMM的原理主要优势代码示例:各个符号的解释:**梯度的几何含义**:具体数学公式:**应用示例**:**ADMM的标准形式:****ADMM中的变量角色:****ADMM中的更…

`we_chat_union_id IS NOT NULL` 和 `we_chat_union_id != ‘‘` 这两个条件之间的区别

文章目录 1、什么是空字符串?2、两个引号之间加上空格 好的,我们来详细解释一下 we_chat_union_id IS NOT NULL 和 we_chat_union_id ! 这两个条件之间的区别,以及它们在 SQL 查询中的作用: 1. we_chat_union_id IS NOT NULL 含…

随机变量是一个函数-如何理解

文章目录 一. 随机变量二. 随机变量是一个函数-栗子(一对一)1. 掷骰子的随机变量2. 掷骰子的随机变量(求点数平方)3. 抛硬币的随机变量4. 学生考试得分的随机变量 三. 随机变量是一个函数-理解(多对一) 一. 随机变量 随机变量就是定义在样本空间上的函数…

jwt在express中token的加密解密实现方法

在我们前面学习了 JWT认证机制在Node.js中的详细阐述 之后,今天来详细学习一下token是如何生成的,secret密钥的加密解密过程是怎么样的。 安装依赖 express:用于创建服务器jsonwebtoken:用于生成和验证JWTbody-parser&#xff1…

大厂开发规范-如何规范的提交Git

多人协作开发提交代码通常是遵循约定式提交规范,如果严格安照约定式提交规范, 手动进行代码提交的话,那么是一件非常痛苦的事情,但是 Git 提交规范的处理又势在必行,那么怎么办呢? 经过了很多人的冥思苦想…

企业安装加密软件有什么好处?

加密软件为企业的安全提供了很多便利,从以下几点我们看看比较重要的几个优点: 1、数据保护:企业通常拥有大量的商业机密、客户数据、技术文档等敏感信息。加密软件可以对这些信息进行加密处理,防止未经授权的人员访问。即使数据被…

【ANGULAR网站开发】初始环境搭建

1. 初始化angular项目 1.1 创建angular项目 需要安装npm和nodejs,这边不在重新安装 直接安装最新版本的angular npm install -g angular/cli安装指定大版本的angular npm install -g angular/cli181.2 启动angular 使用idea启动 控制台启动 ng serve启动成功…

Python 屏幕取色工具

Python 屏幕取色工具 1.简介: 屏幕取色小工具‌是一种实用的软件工具,主要用于从屏幕上精确获取颜色值,非常适合设计、编程等需要精确配色的领域。这类工具通常能够从屏幕上任何区域精确提取颜色值,支持在整数值、RGB值、BGR值之…

宏集eX710物联网工控屏在石油开采机械中的应用与优势

案例概况 客户:天津某石油机械公司 应用产品:宏集eX710物联网工控屏 应用场景:钻井平台设备控制系统 一、应用背景 石油开采和生产过程复杂,涵盖钻井平台、采油设备、压缩机、分离器、管道输送系统等多种机械设备。这些设备通…

实验室服务器Ubuntu安装使用全流程

一、制作U盘启动盘 工具: 一个32G以上的U盘Rufuse镜像烧录软件下载:https://cn.ultraiso.net/xiazai.htmlRufus - 轻松创建 USB 启动盘https://cn.ultraiso.net/xiazai.htmlUbuntu系统镜像:https://ubuntu.com/download/alternative-downlo…

2-198基于Matlab-GUI的运动物体追击问题

基于Matlab-GUI的运动物体追击问题,定义目标航速、航线方向、鱼雷速度,并设置目标和鱼雷初始位置,根据航速和航向优化鱼雷路径,实现精准打击。程序已调通,可直接运行。 2-198基于Matlab-GUI的运动物体追击问题

实验五 时序逻辑电路部件实验

一、实验目的 熟悉常用的时序逻辑电路功能部件,掌握计数器、了解寄存器的功能。 二、实验所用器件和仪表 1、双 D触发器 74LS74 2片 2、74LS162 1片 3、74194 1片 4、LH-D4实验仪 1台 1.双…

UnityURP 自定义PostProcess之深度图应用

UnityURP 自定义PostProcess之深度图 前言项目Shader代码获取深度图ASE连线获取深度图 前言 在Unity URP中利用深度图可以实现以下两种简单的效果,其他设置参考 UnityURP 自定义PostProcess 项目 Shader代码获取深度图 Shader "CustomPost/URPScreenTintSha…

PlasmidFinder:质粒复制子的鉴定和分型

质粒(Plasmid)是一种细菌染色体外的线性或环状DNA分子,也是一种重要的遗传元素,它们具有自主复制能力,可以在细菌之间传播,并携带多种重要的基因(如耐药基因与毒力基因等)功能。根据质粒传播的特性&#xf…