基于强化学习的自动化裁剪CIFAR-10 分类任务(提升模型精度+减少计算量)

news2024/10/6 10:34:12

基于强化学习的自动化裁剪,提升模型精度的同时减少计算量。

介绍

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RFnHlyQG-1691544546106)(./pic/APT-main.png)]

目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境。本项目的核心思想是:将模型视为环境,构建附生于模型的 agent ,以辅助模型进一步拟合真实样本。

大多数领域的模型都可以采用这种方式来优化,如cv\多模态等。它至少能够以三种方式工作:1.过滤噪音信息,如删减语音或图像特征;2.进一步丰富表征信息,如高效引用外部信息;3.实现记忆、联想、推理等复杂工作,如构建重要信息的记忆池。

这里推出一款早期完成的裁剪机制transformer版本(后面称为APT),实现了一种更高效的训练模式,能够优化模型指标;此外,可以使用动态图丢弃大量的不必要单元,在指标基本不变的情况下,大幅降低计算量。

该项目希望为大家抛砖引玉。

在这里插入图片描述

为什么要做自动剪枝

在具体任务中,往往存在大量毫无价值的信息和过渡性信息,有时不但对任务无益,还会成为噪声。比如:表述会存在冗余/无关片段以及过渡性信息;动物图像识别中,有时候背景无益于辨别动物主体,即使是动物部分图像,也仅有小部分是关键的特征。

在这里插入图片描述

以transformer为例,在进行self-attention计算时其复杂度与序列长度平方成正比。长度为10,复杂度为100;长度为9,复杂度为81。

利用强化学习构建agent,能够精准且自动化地动态裁剪已丧失意义部分,甚至能将长序列信息压缩到50-100之内(实验中有从500+的序列长度压缩到个位数的示例),以大幅减少计算量。

实验中,发现与裁剪agent联合训练的模型比普通方法训练的模型效果要更好。

使用说明

环境

torch
numpy
tqdm
tensorboard
ml-collections

下载经过预先​​训练的模型(来自Google官方)

本项目使用的型号:ViT-B_16(您也可以选择其它型号进行测试)

# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz

训练与推理

下载好预训练模型就可以跑了。

# 训练
python3 train.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

# 推理
python3 infer.py --name cifar10-100_500 --dataset cifar100 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

CIFAR-10和CIFAR-100会自动下载和培训。如果使用其他数据集,您需要自定义data_utils.py。

在裁剪模式的推理过程中,预期您将看到如下格式的输出。

Validating... (loss=0.13492):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 196, 768])
第5层形状::: torch.Size([1, 188, 768])
第8层形状::: torch.Size([1, 186, 768])
Validating... (loss=0.01283):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 183, 768])
第5层形状::: torch.Size([1, 166, 768])
第8层形状::: torch.Size([1, 166, 768])
Validating... (loss=3.71401):   1%|| 60/10000 [00:01<02:36, 63.34it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 191, 768])
第8层形状::: torch.Size([1, 186, 768])
Validating... (loss=0.00328):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 191, 768])
第5层形状::: torch.Size([1, 164, 768])
第8层形状::: torch.Size([1, 123, 768])
Validating... (loss=0.03190):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 187, 768])
第8层形状::: torch.Size([1, 160, 768])
Validating... (loss=0.00356):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 193, 768])
第5层形状::: torch.Size([1, 187, 768])
第8层形状::: torch.Size([1, 182, 768])
Validating... (loss=0.00297):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 197, 768])
第5层形状::: torch.Size([1, 167, 768])
第8层形状::: torch.Size([1, 162, 768])
Validating... (loss=0.00162):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 189, 768])
第5层形状::: torch.Size([1, 179, 768])
第8层形状::: torch.Size([1, 157, 768])
Validating... (loss=0.08821):   1%|| 67/10000 [00:01<02:35, 63.93it/s]
初始输入形状::: torch.Size([1, 197, 768])
第2层形状::: torch.Size([1, 197, 768])
第5层形状::: torch.Size([1, 174, 768])
第8层形状::: torch.Size([1, 156, 768])

默认的batch size为72、gradient_accumulation_steps为3。当GPU内存不足时,您可以通过它们来进行训练。

注:相较于原始的ViT,APT(Automatic pruning transformer)的训练步数、训练耗时都会上升。原因是使用pruning agent的模型由于总会丢失部分信息,使得收敛速度变慢,同时为了训练pruning agent,也需要多次的观测、行动、反馈。

模型介绍

模型主体

基于transformer的视觉预训练模型ViT是本项目的模型主体,具体细节可以查看论文:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》
在这里插入图片描述

自动化裁剪的智能体

对于强化学习agent来说,最关键的问题之一是如何衡量动作带来的反馈。为了评估单次动作所带来的影响,使用了以下三步骤:

1.使用一个普通模型(无裁剪模块)进行预测;

2.使用一个带裁剪器的模型(执行一次裁剪动作)进行预测;

3.对比两次预测的结果,若裁剪后损失相对更小,则说明该裁剪动作帮助了模型进一步拟合真实状况,应该得到奖励;反之,应该受到惩罚。

但是在实际预测过程中,模型是同时裁剪多个单元的,这或将因为多个裁剪的连锁反应而导致模型失效。训练过程中需要构建一个带裁剪器的模型(可执行多次裁剪动作),以减小该问题所带来的影响。

综上,本模型使用的是三通道模式进行训练。

在这里插入图片描述

关于裁剪器的模型结构设计,本模型中认为如何衡量一个信息单元是否对模型有意义,建立于其自身的信息及它与任务的相关性上。

因此以信息单元本身及它与CLS单元的交互作为agent的输入信息。

在这里插入图片描述

实验


数据集ViTAPT(pruning)APT(no pruning)
CIFAR-10092.392.693.03
CIFAR-1099.0898.9398.92

以上加载的均为ViT-B_16,resolution为224*224。

致谢

感谢基于pytorch的图像分类项目,本项目是在此基础上做的研发。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/879632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据:什么是数据分析及环境搭建

一、什么是数据分析 当今世界对信息技术的依赖程度在不断加深&#xff0c;每天都会有大量的数据产生&#xff0c;我们经常会感到数据越来越多&#xff0c;但是要从中发现有价值的信息却越来越难。这里所说的信息&#xff0c;可以理解为对数据集处理之后的结果&#xff0c;是从…

【Sklearn】基于逻辑回归算法的数据分类预测(Excel可直接替换数据)

【Sklearn】基于逻辑回归算法的数据分类预测&#xff08;Excel可直接替换数据&#xff09; 1.模型原理2.模型参数3.文件结构4.Excel数据5.下载地址6.完整代码7.运行结果 1.模型原理 逻辑回归是一种用于二分类问题的统计学习方法&#xff0c;尽管名字中含有“回归”&#xff0c…

ORCA优化器浅析——IMDRelation Storage type of a relation GP6与GP7对比

如上图所示IMDRelation作为Interface for relations in the metadata cache&#xff0c;其定义了Storage type of a relation表的存储类型&#xff0c;如下所示&#xff1a; enum Erelstoragetype {ErelstorageHeap,ErelstorageAppendOnlyCols,ErelstorageAppendOnlyRows,Erels…

如何使用CSS实现一个模态框(Modal)效果?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 使用CSS实现模态框&#xff08;Modal&#xff09;效果⭐ HTML 结构⭐ CSS 样式⭐ JavaScript⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎…

重磅:谷歌发布多平台应用 AI 编程神器

前几天&#xff0c; 谷歌发布了一个多平台应用开发神器&#xff1a;IDX 。 IDX 背靠 AI 编程神器 Codey&#xff0c;支持 React、Vue 等框架&#xff0c;还能补全、解释代码。 更有特色的一点就是&#xff1a;这是一款基于浏览器的开发全栈、用于多平台应用开发的工具。 这款开…

C语言题目的多种解法分享 2之字符串左旋和补充题

前言 有的时候&#xff0c;这个系列专栏中的解法之间并无优劣&#xff0c;只是给大家提供不同的解题思路 我决定将代码实现的过程写成注释&#xff0c;方便大家直接找到对应的函数&#xff0c;只有需要补充说明的知识才会单拿出来强调 这个系列的文章会更的比较慢&#xff0…

级联(数据字典)

二级级联&#xff1a; 一&#xff1a;新建两个Bean 父级&#xff1a; /*** Description 数据字典* Author WangKun* Date 2023/7/25 10:15* Version*/ Data AllArgsConstructor NoArgsConstructor TableName("HW_DICT_KEY") public class DictKey implements Seri…

学习笔记整理-JS-06-函数

一、函数基本使用 1. 什么是函数 函数就是语句的封装&#xff0c;可以让这些代码方便地被复用。函数具有"一次定义&#xff0c;多次调用"的优点。使用函数&#xff0c;可以简化代码&#xff0c;让代码更具有可读性。 2. 函数的定义和调用 和变量类似&#xff0c;函…

C++:模拟实现list及迭代器类模板优化方法

文章目录 迭代器模拟实现 本篇模拟实现简单的list和一些其他注意的点 迭代器 如下所示是利用拷贝构造将一个链表中的数据挪动到另外一个链表中&#xff0c;构造两个相同的链表 list(const list<T>& lt) {emptyinit();for (auto e : lt){push_back(e);} }void test_…

【ES】【elasticsearch】分布式搜索

文章目录 ☀️安装elasticsearch☀️1.部署单点es&#x1f338;1.1.创建网络&#x1f338;1.2.下载镜像&#x1f338;1.3.运行 ☀️2.部署kibana&#x1f338;2.1.部署&#x1f338;2.2.DevTools ☀️3.安装IK分词器&#x1f338;3.1.在线安装ik插件&#xff08;较慢&#xff0…

ARM汇编快速入门

本文主要分享如何快速上手ARM汇编开发的经验、汇编开发中常见的Bug以及Debug方法、用的Convolution Dephtwise算子的汇编实现相对于C版本的加速效果三方面内容。 前言 神经网络模型能够在移动端实现快速推理离不开高性能算子&#xff0c;直接使用ARM汇编指令来进行算子开发无疑…

ad+硬件每日学习十个知识点(32)23.8.12 (元器件封装、PCB封装、3D的PCB封装)

文章目录 1.元器件封装属性值说明2.PCB封装标准说明&#xff08;M、N、L&#xff09;3.电阻的PCB封装&#xff08;阻焊层&#xff09;4.电感的PCB封装&#xff08;CD、CDRH&#xff09;1.CD31的意思是&#xff0c;直径3mm&#xff0c;高度1mm![在这里插入图片描述](https://img…

【SQL应知应会】索引(二)• MySQL版

欢迎来到爱书不爱输的程序猿的博客, 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 本文收录于SQL应知应会专栏,本专栏主要用于记录对于数据库的一些学习&#xff0c;有基础也有进阶&#xff0c;有MySQL也有Oracle 索引 • MySQL版 前言一、索引1.简介2.创建2.1 索引…

Gradio——快速部署可视化人智能应用

前言 Gradio是一个开源的Python库&#xff0c;用于快速构建机器学习和数据科学演示的应用。它可以帮助你快速创建一个简单漂亮的用户界面&#xff0c;以便向客户、合作者、用户或学生展示你的机器学习模型。此外&#xff0c;还可以通过自动共享链接快速部署模型&#xff0c;并获…

IntelliJ IDEA热部署:JRebel插件的安装与使用

热部署 概述JRebel 概述 热部署&#xff0c;指修改代码后&#xff0c;无需停止应用程序&#xff0c;即可使修改后的代码生效&#xff0c;其有利于提高开发效率。 热部署方式&#xff1a; 手动热部署&#xff1a;修改代码后&#xff0c;重新编译项目&#xff0c;然后启动应用程…

【软件测试】Linux系统下安装jdk配置环境变量(详细步骤)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、安装环境 操作…

UGUI事件系统EventSystem

一. 事件系统概述 Unity的事件系统具有通过鼠标、键盘、游戏控制柄、触摸操作等输入方式&#xff0c;将事件发送给对象的功能。事件系统通过场景中EventSystem对象的组件EventSystem和Standalone Input Module发挥功能。EventSystem对象通常实在创建画布的同时被创建的&#xf…

c++杂谈-5

目录 一、一些c11新特性1. decltype关键字及函数后置返回类型2. 函数后置返回类型 二、c的内存模型三、函数指针和回调函数四、函数模版的注意事项 一、一些c11新特性 1. decltype关键字及函数后置返回类型 在C11中&#xff0c;decltype操作符&#xff0c;用于查询表达式的数…

HarmonyOS NEXT新能力,一站式高效开发HarmonyOS应用

2023年8月6日华为开发者大会2023&#xff08;HDC.Together&#xff09;圆满收官&#xff0c;伴随着HarmonyOS 4的发布&#xff0c;华为向开发者发布了汇聚所有最新开发能力的HarmonyOS NEXT开发者预览版&#xff0c;并分享了围绕“一次开发&#xff0c;多端部署” “可分可合&a…

通过Microsoft Loopback Adapter实现虚拟机和物理机的通信

问题 问&#xff1a;不借助路由器或交换机的情况下&#xff0c;能不能实现主机和虚拟及之间两个软件的通信呢&#xff1f;要求主机和虚拟及均有独立的ip地址&#xff0c;从而进行指定源的组播通信。 答&#xff1a;可以。通过借助虚拟网络适配器&#xff0c;不需要路由器或交…