Code for VeLO 1: Training Versatile Learned Optimizers by Scaling Up

news2025/1/22 15:57:54

Code for VeLO 1: Training Versatile Learned Optimizers by Scaling Up

这篇文章将介绍一下怎么用VeLO进行训练。

这篇文章基于https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB-,将介绍使用learned optimizer in the VeLO family:

  • 一个简单的图片识别人物
  • resetnets 下一篇文章

Accelerator Setup、依赖安装和导入

# 设置Accelerator的类型,一般在实验室中只有GPU
Accelerator_Type = 'GPU' #@param ["GPU", "TPU", "CPU"]

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)


# install lopt
# learned_optimization 这个库中包含了
!pip install git+https://github.com/google/learned_optimization.git

# jax 是 TensorFlow 的一个简化库,名为 JAX,结合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加简洁易用。
import jax
if Accelerator_Type == 'TPU':
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

from learned_optimization.tasks.fixed import image_mlp
from learned_optimization.research.general_lopt import prefab
from learned_optimization import eval_training

from matplotlib import pylab as plt
from learned_optimization import notebook_utils as nu
import numpy as onp


from learned_optimization.baselines import utils
import os
# use the precomputed baselines folder from gcp for loading baseline training curves
# 这句话我不是很清楚是什么含义 emm
os.environ["LOPT_BASELINE_ARCHIVES_DIR"] = "gs://gresearch/learned_optimization/opt_archives/"

使用Optax style的优化器

jax有自己的一个示例版优化库optimizers,不过这个库非常的小,都没实现学习率训练计划schedule,当然也可以自己写一个函数,learning_rate_fn(steps),然后作为参数传入optimizers.sgd(step_size=learning_rate_fn)即可。

如果自己写比较麻烦,就可以用optax库。https://zhuanlan.zhihu.com/p/545561011

import optax
# defining an optimizer that targets 1000 training steps
NUM_STEPS = 1000 # 这里是制定优化器要执行的步数
opt = prefab.optax_lopt(NUM_STEPS)  # 定义优化器

定义和执行一个简单的训练循环

# Learned_optimization contains a handful of predefined tasks.  These tasks
# wrap the model initialization and dataset definitions in one convenient
# object.  Here, we initialize a simple MLP for the fashionmnist dataset.
# 一个手动预定义的task,包装了MLP model + fashionmnist dataset
task = image_mlp.ImageMLP_FashionMnist8_Relu32()

# We initialize the underlying MLP and collect its state using its init
# function.  Under the hood, this is really just initializing a haiku model
# as seen here (https://github.com/google/learned_optimization/blob/main/learned_optimization/tasks/fixed/image_mlp.py#L58).
# 初始化这个模型
key = jax.random.PRNGKey(0)
params = task.init(key)

# finally, we initialize the optimizer with the model state:
# 使用model的state来初始化优化器
opt_state = opt.init(params)

# 在训练循环中,我们只需要这么一个update函数
# For a training loop, all we need is an update function.  This update function
# takes existing optimizer state优化器参数, model params模型参数, training data训练数据, and randomness随机数
# as args, and returns new optimizer state, new model params, and the loss.
# import jax 

@jax.jit
def update(opt_state, params, data, key): 
  """Simple training update function.
  Args:
    opt_state: Optimizer state
    params: Model parameter weights
    data: Training data
    key: Jax randomness
  
  Returns:
    A tuple of updated optimizer state, model state, and the current loss.
    返回一个元组:优化器的参数、模型的参数、还有当前的loss,(训练数据已经用了,不需要返回"""
  l, g = jax.value_and_grad(task.loss)(params, key, data)

  # 我猜测:这里的优化器应该是默认frozen的,然后
  updates, opt_state = opt.update(g, opt_state, params=params, extra_args={"loss": l}) 
  params = optax.apply_updates(params, updates)  # 对模型的参数进行更新
  return opt_state, params, l


# a simple training loop

losses = []
for i in range(NUM_STEPS):
  batch = next(task.datasets.train) # 从训练集中拿出数据出来
  key1, key = jax.random.split(key) # 随机数的处理
  opt_state, params, l = update(opt_state, params, batch, key1)  # 执行update函数
  losses.append(l) 

绘制一下loss的图像

# here we visualize the loss during training
plt.plot(losses)

image-20230115211105491

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/165393.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

入门力扣自学笔记230 C++ (题目编号:2293)

2293. 极大极小游戏 题目: 给你一个下标从 0 开始的整数数组 nums ,其长度是 2 的幂。 对 nums 执行下述算法: 设 n 等于 nums 的长度,如果 n 1 ,终止 算法过程。否则,创建 一个新的整数数组 newNums …

【Python百日进阶-数据分析】Day226 - plotly的仪表盘go.Indicator()

文章目录一、语法二、参数三、返回值四、实例4.1 Bullet Charts子弹图4.1.1 基本子弹图4.1.2 添加步骤和阈值4.1.3 自定义子弹4.1.4 多子弹4.2 径向仪表图4.2.1 基本仪表4.2.2 添加步骤、阈值和增量4.2.3 自定义仪表图4.3 组合仪表图4.3.1 组合仪表图4.3.2 单角量规图4.3.3 子弹…

Android 深入系统完全讲解(19)

技术的学习关键点 是什么?思路。 而我这里分享一个学习的经典路线,先厘清总框架,找到思路,然后再逐步击破。 这里关于音视频的就是: 总体分为几部分: 1 绘制 2 编解码格式 3 Android 平台的 FFmpeg 开源移…

Compressed Sensing——从零开始压缩感知

Problem 考虑一个线性方程组求解问题: Axb(1)A x b \tag{1}Axb(1) 其中,A∈RmnA \in\mathbb R^{m\times n}A∈Rmn,x∈Rn1x \in\mathbb R^{n\times 1}x∈Rn1,b∈Rm1b \in\mathbb R^{m\times 1}b∈Rm1且m≪nm \ll nm≪n 这是一个…

【C++11】—— lambda表达式

目录 一、lambda表达式的简介 二、lambda表达式的基本语法 三、lambda表达式的使用方法 四、lambda表达式的底层原理 一、lambda表达式的简介 lambda表达式就类似于仿函数,相比仿函数要更加的简洁,我们看一下下面的代码: //商品类 struct…

【项目实战】使用MybatisPlus乐观锁插件功能

一、背景 当要更新一条记录时,希望这条记录没有被别人更新,可以考虑使用MybatisPlus乐观锁插件功能来实现以上需求。 二、乐观锁介绍 2.1 乐观锁是什么? 乐观锁是一种乐观思想,即认为读多写少,遇到并发的可能性低&…

使用ASM框架创建ClassVisitor时遇到IllegalArgumentException的一种可能解决办法

背景 ASM是java语言中最为广泛使用的插装框架,其优点在于可以动态地在运行时改变java系统的行为,加入我们自己的逻辑。在软件测试领域应用广泛。但是其使用难度很高,一方面使用asm框架需要对java底层知识有较高的了解,另一方面网…

网页共享电脑屏幕与播放(带声音)

这次项目我们是写的一个课堂辅助软件的网页版,其中有一个功能感觉能作为我们项目的一个亮点,就是直播功能,在之前并没有写过这个东西。虽然现在这个功能还不知道怎么写,但是它的流程终归是利用视频流将本地的视频给共享出去&#…

Verilog:【8】基于FPGA实现SD NAND FLASH的SPI协议读写

碎碎念: 终于熬过了期末周,可以开始快乐的开发之旅了。 这一期作为一千粉后的首篇博客,由于之后项目会涉及到相关的部分,因此介绍的是使用FPGA实现SD NAND FLASH的读写操作,以雷龙科技提供的SD NAND FLASH样品为例&…

实证分析权重系数计算大全

在实际研究中,权重计算是一种常见的分析方法,需要结合数据的特征情况进行选择,比如数据之间的波动性是一种信息量,那么可考虑使用CRITIC权重法或信息量权重法;也或者专家打分数据,那么可使用AHP层次法或优序…

直观感受PromQL及其数据类型

由于PromQL内容较多,将内容分为三篇文章讲述: 一、直观感受PromQL及其数据类型 二、PromQL之选择器和运算符 三、PromQL之函数 想必都知道要使用Msql,必须会用SQL,同样要使用Prometheus 就要掌握PromQL(Prometheus Que…

【链表】leetcode142.环形链表II(C/C++/Java/Js)

leetcode142.环形链表II1 题目2 思路2.1 判断链表是否有环--快慢指针法2.2 如果有环,如何找到这个环的入口2.3 补充3 代码3.1 C版本3.2 C版本3.3 Java版本3.4 JavaScript版本4 总结1 题目 题源链接 给定一个链表的头节点 head ,返回链表开始入环的第一个…

软测复习05:基于质量特征的测试

作者:非妃是公主 专栏:《软件测试》 个性签:顺境不惰,逆境不馁,以心制境,万事可成。——曾国藩 文章目录性能测试压力测试容量测试健壮性测试安全性测试可靠性测试恢复性测试协议一致性测试兼容性测试安装…

【数据结构】保姆级单链表教程(概念、分类与实现)

目录 🍊前言🍊: 🍈一、链表概述🍈: 1.链表的概念及结构: 2.链表存在的意义: 🍓二、链表的分类🍓: 🥝三、单链表的实现&#x1f…

​盘点几款国内外安全稳定的域名解析平台​

众所周知,有了域名后想建站使用,必须要先解析域名。域名使用注册商一般会提供域名解析服务,这虽然为用户提供了方便,但功能大多有限,使用第三方域名解析平台就成了非常必要的选择。今天,小编就为大家盘点几…

计算机视觉OpenCv学习系列:第四部分、键盘+鼠标响应操作

第四部分、键盘鼠标响应操作第一节、键盘响应操作1.键盘响应事件2.键盘响应3.代码练习与测试第二节、鼠标操作与响应1.鼠标事件与回调2.鼠标操作3.代码练习与测试学习参考第一节、键盘响应操作 键盘响应中有一个函数叫做waitKey,所有的获取键盘键值都是通过waitKey…

【经典笔试题】动态内存管理

test1:void GetMemory(char* p) {p (char*)malloc(100); } void Test(void) {char* str NULL;GetMemory(str);strcpy(str, "hello world");printf(str); }int main() {Test();return 0; }请问执行上面代码,会出现什么结果?解析&a…

7. R语言【独立性检验】:卡方独立性检验、Fisher精确检验 、Cochran-Mantel-Haenszel检验

文章目录1. 卡方检验2. 费希尔精确检验(Fisher Exact Test)3. Cochran-Mantel-Haenszel检验独立性检验:用来判断变量之间相关性的方法,如果两个变量彼此独立,那么两者统计上就是不相关的 1. 卡方检验 可以使用chisq.…

Java面向对象之多态、内部类、常用API

目录面向对象之三大特性之三:多态多态的概述、多态的形式多态的好处多态下引用数据类型的类型转换多态的综合案例内部类内部类概述内部类之一:静态内部类内部类之二:成员内部类内部类之三:局部内部类内部类之四:匿名内…

JavaSE与网络面试题

大佬的: https://github.com/Snailclimb/JavaGuide https://osjobs.net/topk/all/ 自增自减 要点: 赋值 ,最后计算 右边的从左到右加载值,一次压入操作数栈 实际先算哪个看运算符的优先级 自增、自减操作都是直接修改变量…