如何在深度学习中调用CAME

news2024/12/23 19:56:53

1、介绍

CAME:一种以置信度为导向的策略,以减少现有内存高效优化器的不稳定性。基于此策略,我们提出CAME同时实现两个目标:传统自适应方法的快速收敛和内存高效方法的低内存使用。大量的实验证明了CAME在各种NLP任务(如BERT和GPT-2训练)中的训练稳定性和优异的性能。

2、Pytorch中调用该优化算法

(1)定义CAME

import math

import torch
import torch.optim

class CAME(torch.optim.Optimizer):
    """Implements CAME algorithm.
    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and instability respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of root-mean-square of
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficient used for computing running averages of
        update, square gradient and instability (default: (0.9, 0.999, 0.9999)))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super(CAME, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False


    def _get_options(self, param_shape):
        factored = len(param_shape) >= 2
        return factored

    def _rms(self, tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored = self._get_options(grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).type_as(grad)

                        state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)
                        state["exp_avg_res_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).type_as(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0

                state["step"] += 1
                state["RMS"] = self._rms(p.data)

                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(group["betas"][1]).add_(
                        update.mean(dim=-1), alpha=1.0 - group["betas"][1]
                    )
                    exp_avg_sq_col.mul_(group["betas"][1]).add_(
                        update.mean(dim=-2), alpha=1.0 - group["betas"][1]
                    )

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
                    update = exp_avg_sq.rsqrt().mul_(grad)

                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                )

                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])

                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg)**2 + group["eps"][1]

                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]

                    exp_avg_res_row.mul_(group["betas"][2]).add_(
                        res.mean(dim=-1), alpha=1.0 - group["betas"][2]
                    )
                    exp_avg_res_col.mul_(group["betas"][2]).add_(
                        res.mean(dim=-2), alpha=1.0 - group["betas"][2]
                    )

                    # Approximation of exponential moving average of instability
                    res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                    update = res_approx.mul_(exp_avg)
                else:
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p.data.add_(
                            p.data, alpha=-group["weight_decay"] * group["lr"]
                        )

                update.mul_(group["lr"])
                p.data.add_(-update)

        return loss

(2)在深度学习中调用CAME优化器

本文以使用LSTM算法对鸢尾花数据集进行分类为例,并且在代码中加入早停和十折交叉验证技术。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

# 定义 LSTM 模型
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        out = self.fc(hn[-1])  # 选择最后一个 LSTM 隐层输出
        return out

# 早停
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def step(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# 读取数据
iris = load_iris()
X = iris.data
y = iris.target

# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 将数据转换为 PyTorch 张量
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# 配置模型参数
input_size = X.shape[1]  # 特征数量
hidden_size = 32
num_classes = 3
batch_size = 16
num_epochs = 100
learning_rate = 0.001
patience = 5

# 进行十折交叉验证
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
fold_idx = 0

for train_index, val_index in kf.split(X, y):
    fold_idx += 1
    print(f"Fold {fold_idx}")

    # 划分训练集和验证集
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]

    # 定义模型和优化器
    model = LSTMClassifier(input_size, hidden_size, num_classes)
    optimizer = CAME(model.parameters(), lr=2e-4, weight_decay=1e-2, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # 早停设置
    early_stopping = EarlyStopping(patience=patience)

    # 训练模型
    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train.unsqueeze(1))
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        
        # 验证阶段
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val.unsqueeze(1))
            val_loss = criterion(val_outputs, y_val)
        
        # 打印每轮迭代的损失值
        print(f"Epoch {epoch + 1}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss.item():.4f}")

        # 早停检查
        early_stopping.step(val_loss.item())
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    # 评估模型
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val.unsqueeze(1))
        _, predicted = torch.max(val_outputs, 1)
        accuracy = accuracy_score(y_val, predicted)
        print(f"Fold {fold_idx} Validation Accuracy: {accuracy:.4f}\n")

使用LSTM算法对鸢尾花数据集分类结果
由于CAME主要面向NLP数据集,因此对于鸢尾花效果不算好,本文仅展示CAME的使用方法,并非提升acc和epoch。

参考文献:Luo, Yang, et al. “CAME: Confidence-guided Adaptive Memory Efficient Optimization.” arXiv preprint arXiv:2307.02047 (2023).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607930.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【python】直接在python3下安装 jupyter notebook,以及处理安装报错,启动不了问题

目录 问题&#xff1a; 1 先做准备&#xff0c;查看环境 1.1 先看python3 和pip &#xff0c;以及查看是否有 juypter 1.2 开始安装 1.3 安装完成后得到警告和报错 2 处理安装的报错问题 2.1 网上有说是因为 pip 自身需要更新&#xff0c;更新之 2.1.1 更新pip 2.1.…

vue快速入门(三十二)局部与全局注册组件的步骤

注释很详细&#xff0c;直接上代码 上一篇 新增内容 局部注册组件全局注册组件 文件结构 源码 MyHeader.vue <!-- 用于测试全局注册组件 --> <template><div><h1>又可以愉快的学习啦</h1></div> </template><script>export d…

开启Three.js之旅(会持续完善)

文章目录 Three.js必备构建项目场景Scene相机CameraPerspectiveCamera 渲染器WebGLRendererCSS3DRenderer 灯光LightAmbientLightDirectionalLight 平行光PointLight 加载器CacheFileLoaderLoaderGLTFLoaderRGBELoaderTextureLoader 材质MetarialMeshBasicMaterialMeshLambertM…

武汉星起航:上海股权中心成功挂牌,创始人张振邦领航跨境新纪元

在金秋十月的尾声&#xff0c;上海股权托管交易中心迎来了一场备受瞩目的盛事。2023年10月30日&#xff0c;武汉星起航电子商务有限公司成功挂牌展示&#xff0c;正式登录资本市场&#xff0c;开启了一段崭新的发展篇章。这一里程碑式的跨越&#xff0c;不仅标志着武汉星起航在…

MySQL基础-----约束详解

目录 一. 概述: 二.约束演示&#xff1a; 三.外键约束&#xff1a; 3.1介绍&#xff1a; 3.2外键约束语法&#xff1a; 3.3删除&#xff0c;更新行为&#xff1a; 一. 概述: &#x1f9d0;&#x1f9d0;概念&#xff1a;约束是作用于表中字段上的规则&#xff0c;用于限制…

【机器学习】数据变换---小波变换特征提取及应用案列介绍

引言 在机器学习领域&#xff0c;数据变换是一种常见且重要的预处理步骤。通过对原始数据进行变换&#xff0c;我们可以提取出更有意义的特征&#xff0c;提高模型的性能。在众多数据变换方法中&#xff0c;小波变换是一种非常有效的方法&#xff0c;尤其适用于处理非平稳信号和…

会话seesion的使用,结合ddddocr识别简单验证码的登录实现。

古诗文网登录代码&#xff1a; # 古诗文网登录实战 # 验证码链接:https://so.gushiwen.cn/RandCode.ashx # 变动参数链接:__VIEWSTATE所在的地址:https://so.gushiwen.cn/user/login.aspx?fromhttp://so.gushiwen.cn/user/collect.aspx # 登录接口链接:https://so.gushiwen.c…

npx\pnpm 镜像过期解决方法

. // 1. 清空缓存 npm cache clean --force // 2. 关闭SSL验证 npm config set strict-ssl false // 3. 安装 到这里就可以正常使用npm命令安装需要的工具了。如( npm install -g cnpm )

代码学习记录48---单调栈

随想录日记part48 t i m e &#xff1a; time&#xff1a; time&#xff1a; 2024.04.19 主要内容&#xff1a;今天开始要学习单调栈的相关知识了&#xff0c;今天的内容主要涉及&#xff1a;503.下一个更大元素II ;42. 接雨水 503.下一个更大元素II 42. 接雨水 Topic1下一个更…

第二部分 Python提高—GUI图形用户界面编程(六)

其他组件学习 文章目录 OptionMenu 选择项Scale 移动滑块颜色选择框文件对话框简单输入对话框通用消息框ttk 子模块控件 OptionMenu 选择项 OptionMenu(选择项)用来做多选一&#xff0c;选中的项在顶部显示。显示效果如下&#xff1a; from tkinter import * root Tk();ro…

电弧的产生机理

目录&#xff1a; 1、起弧机理 2、电弧特点 3、电弧放电特点 4、实际意义 1&#xff09;电力开关装置 2&#xff09;保险丝 1、起弧机理 电弧的本质是一种气体放电现象&#xff0c;可以理解为绝缘情况下产生的高强度瞬时电流。起弧效果如下图所示&#xff1a; 在电场的…

5 CatBoost模型

目录 1 背景 2 原理 2.1 类别特征处理 2.1.1 传统目标编码&#xff1a; TS 2.1.2 Greedy TS 2.1.3 ordered TS编码 2.1.4 CatBoost处理Categorical features总结 2.2.预测偏移处理 2.2.1 梯度无偏估计 2.3 树的构建​​​​​​​ 3 优缺点 优点 4 代码 1 背景 终于…

关系图卷积神经网络

异质图和知识图谱 同质图与异质图 同质图指的是图中的节点类型和关系类型都仅有一种 异质图是指图中的节点类型或关系类型多于一种 知识图谱 知识图谱包含实体和实体之间的关系&#xff0c;并以三元组的形式存储&#xff08;<头实体, 关系, 尾实体>&#xff0c;即异…

Three.js——聚光灯、环境光、点光源、平行光、半球光

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

阿里云服务器怎么更换暴露的IP

很多客户阿里云服务器被攻击IP暴露&#xff0c;又不想迁移数据换服务器&#xff0c;其实阿里云服务器可以更换IP&#xff0c;今天就来和大家说说流程&#xff0c;云服务器创建成功后6小时内可以免费更换公网IP地址三次&#xff0c;超过6小时候就只能通过换绑弹性公网IP的方式来…

探索人工智能绘图的奇妙世界

探索人工智能绘图的奇妙世界 人工智能绘图的基本原理机器之美&#xff1a;AI绘图作品AI绘图对艺术创作的影响未来展望与挑战图书推荐&#x1f449;AI绘画教程&#xff1a;Midjourney使用方法与技巧从入门到精通内容简介获取方式&#x1f449;搜索之道&#xff1a;信息素养与终身…

访问云平台中linux系统图形化界面,登录就出现黑屏的问题解决(ubuntu图形界面)

目录 一、问题-图形化界面访问黑屏 二、系统环境 &#xff08;一&#xff09;网络结构示意图 &#xff08;二&#xff09;内部机器版本 三、分析 四、解决过程 &#xff08;一&#xff09;通过MobaXterm远程访问图形化界面(未成功) 1、连接方法 2、连接结果 &#xf…

【新版】系统架构设计师 - 知识点 - 结构化开发方法

个人总结&#xff0c;仅供参考&#xff0c;欢迎加好友一起讨论 文章目录 架构 - 知识点 - 结构化开发方法结构化开发方法结构化分析结构化设计 数据流图和数据字典模块内聚类型与耦合类型 架构 - 知识点 - 结构化开发方法 结构化开发方法 分析阶段 工具&#xff1a;数据流图、…

VUE项目使用.env配置多种环境以及如何加载环境

第一步&#xff0c;创建多个环境配置文件 Vue CLI 项目默认使用 .env 文件来定义环境变量。你可以通过创建不同的 .env 文件来为不同环境设置不同的环境变量&#xff0c;例如&#xff1a; .env —— 所有模式共用.env.local —— 所有模式共用&#xff0c;但不会被 git 提交&…

算法模板-线段树+懒标记

视频连接&#xff1a;C02【模板】线段树懒标记 Luogu P3372 线段树 1_哔哩哔哩_bilibili 题目链接&#xff1a;P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) P3374 【模板】树状数组 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 算法思路 递…