yolov11剪枝

news2024/11/30 18:44:22

思路:yolov11中的C3k2与yolov8的c2f的不同,所以与之前yolov8剪枝有稍许不同;

后续:会将剪枝流程写全,以及增加蒸馏、注意力、改loss;

注意:

1.在代码105行修改pruning.get_threshold(yolo.model, 0.65),可以获得不同的剪枝率;

2.改代码放在训练代码同一页面下即可;

3.在最后修改文件夹地址来获得剪枝后的模型;

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os


# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## Normal Pruning
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()

        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## Regular Pruning
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C3k2 as a top conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1

        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.65)  # 这里的0.8为剪枝率。

    ### 1. 剪枝C3k2 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]
    detect: Detect = seq[-1]
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    yolo.val(data='data.yaml', batch=2, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)



if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2250612.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pcb线宽与电流

三十年一路高歌猛进的中国经济&#xff0c; 中国经历了几个三十年&#xff1f; 第一个三十年&#xff1a;以计划为导向。 第二个三十年&#xff1a;以经济为导向。 现在&#xff0c;第三个三十年呢&#xff1f; 应该是以可持续发展为导向。 传统企业摇摇欲坠&#xff0c; 新兴企…

23种设计模式-抽象工厂(Abstract Factory)设计模式

文章目录 一.什么是抽象工厂设计模式&#xff1f;二.抽象工厂模式的特点三.抽象工厂模式的结构四.抽象工厂模式的优缺点五.抽象工厂模式的 C 实现六.抽象工厂模式的 Java 实现七.代码解析八.总结 类图&#xff1a; 抽象工厂设计模式类图 一.什么是抽象工厂设计模式&#xff1f…

VSCode修改资源管理器文件目录树缩进(VSCode目录结构、目录缩进、文件目录外观)workbench.tree.indent

文章目录 方法点击左下角小齿轮点击设置点击工作台&#xff0c;点击外观&#xff0c;找到Tree: Indent设置目录树的缩进 方法 点击左下角小齿轮 点击设置 点击工作台&#xff0c;点击外观&#xff0c;找到Tree: Indent设置目录树的缩进 "workbench.tree.indent"默认…

Transformer.js(七):ONNX 后端介绍 - 它是什么、如何将pytorch模型导出为ONNX格式并在web中使用

在前面的文章中&#xff0c;我介绍了关于transformer.js的一些内容&#xff0c;快速连接&#xff1a; 1. 运行框架的可运行环境、使用方式、代码示例以及适合与不适合的场景2. 关于pipe管道的一切3. 底层架构及性能优化指南4. 型接口介绍5. Tokenizer 分词器接口解析 6. 处理工…

玄机应急:linux入侵排查webshell查杀日志分析

目录 第一章linux:入侵排查 1.web目录存在木马&#xff0c;请找到木马的密码提交 2.服务器疑似存在不死马&#xff0c;请找到不死马的密码提交 3.不死马是通过哪个文件生成的&#xff0c;请提交文件名 4.黑客留下了木马文件&#xff0c;请找出黑客的服务器ip提交 5.黑客留…

消息队列详解:从基础到高级应用

本文主旨 撰写这篇文章的目的在于向读者提供一个全面理解消息队列概念及其在实际应用中重要性的指南。通过从RocketMQ的基础组件如生产者、消费者、主题等的介绍到更高级的概念&#xff0c;比如集群消费与广播消费的区别、顺序消息的重要性等&#xff0c;我们希望能够帮助开发…

qt QGraphicsRotation详解

1、概述 QGraphicsRotation 是 Qt 框架中 QGraphicsTransform 的一个子类&#xff0c;它专门用于处理图形项的旋转变换。通过 QGraphicsRotation&#xff0c;你可以对 QGraphicsItem&#xff08;如形状、图片等&#xff09;进行旋转操作&#xff0c;从而创建动态和吸引人的视觉…

20241129解决在Ubuntu20.04下编译中科创达的CM6125的Android10出现找不到库文件

20241129解决在Ubuntu20.04下编译中科创达的CM6125的Android10出现找不到库文件libtinfo.so.5的问题 2024/11/29 20:41 缘起&#xff1a;中科创达的高通CM6125开发板的Android10的编译环境需要。 [ 11% 15993/135734] target Java source list: vr [ 11% 15994/135734] target …

云轴科技ZStack助力 “上科大智慧校园信创云平台”入选上海市2024年优秀信创解决方案

近日&#xff0c;为激发创新活⼒&#xff0c;促进信创⾏业⾼质量发展&#xff0c;由上海市经济信息化委会同上海市委网信办、上海市密码管理局、上海市国资委等主办的“2024年上海市优秀信创解决方案”征集遴选活动圆满落幕。云轴科技ZStack支持的“上科大智慧校园信创云平台”…

【ArcGIS Pro】实现一下完美的坐标点标注

在CAD里利用湘源可以很快点出一个完美的坐标点标注。 但是在ArcGIS Pro中要实现这个效果却并不容易。 虽然有点标题党&#xff0c;这里就尽量在ArcGIS Pro中实现一下。 01 标注实现方法 首先是准备工作&#xff0c;准备一个点要素图层&#xff0c;包含xy坐标字段。 在地图框…

聚云科技×亚马逊云科技:打通生成式AI落地最后一公里

云计算时代&#xff0c;MSP&#xff08;云管理服务提供商&#xff09;犹如一个帮助企业上云、用云、管理云的专业管家&#xff0c;在云计算厂商与企业之间扮演桥梁的作用。生成式AI浪潮的到来&#xff0c;也为MSP带来全新的生态价值和发展空间。 作为国内领先的云管理服务提供…

brew安装mongodb和php-mongodb扩展新手教程

1、首先保证macos下成功安装了Homebrew&#xff0c; 在终端输入如下命令&#xff1a; brew search mongodb 搜索是不是有mongodb资源&#xff0c; 演示效果如下&#xff1a; 2、下面来介绍Brew 安装 MongoDB&#xff0c;代码如下&#xff1a; brew tap mongodb/brew brew in…

图像显示的是矩阵的行和列,修改为坐标范围。

x 3; y 3; f1x x^2 y^2; guance1 f1x; F (x, y) sqrt((x.^2 y.^2 - guance1).^2); % 使用点乘 [x, y] meshgrid(0:1:5, 0:1:5); Z F(x, y); figure; imagesc(Z); % 由于 imagesc 使用矩阵索引作为坐标&#xff0c;我们需要手动添加刻度 % 这里我们假设 x 和 y 的范围…

深入理解Redis线程模型

前置目标&#xff1a;搭建一个Redis单机服务器。搭建过程参考前面的文档&#xff08;https://blog.csdn.net/Zhuxiaoyu_91/article/details/143904807&#xff09;。 建议调整的redis核心配置&#xff1a; daemonize yes # 允许后台启动 protected‐mode no #关闭保护模…

机器学习实战:泰坦尼克号乘客生存率预测(数据处理+特征工程+建模预测)

项目描述 任务&#xff1a;根据训练集数据中的数据预测泰坦尼克号上哪些乘客能生存下来 数据源&#xff1a;csv文件&#xff08;train.csv&#xff09; 目标变量&#xff1a;Survived&#xff08;0-1变量&#xff09; 数据集预览&#xff1a; 1、英文描述&#xff1a; 2、…

人工智能之数学基础:欧式距离及在人工智能领域中的应用

本文重点 欧式距离,也称为欧几里得距离,是数学中用于衡量多维空间中两点之间绝对距离的一种基本方法。这一概念最早由古希腊数学家欧几里得提出,并以其名字命名。欧式距离的计算基于勾股定理,即在一个直角三角形中,斜边的平方等于两直角边的平方和。在多维空间中,欧式距…

logminer挖掘日志归档查找问题

--根据发生问题时间点查找归档文件 select first_time,NAME from gv$archived_log where first_time>2016-03-15 17:00:00 and first_time<2016-03-15 21:00:00; 2016-03-15 17:23:55 ARCH/jxdb/archivelog/2016_03_15/thread_1_seq_41588.4060.906577337 2016-03-15 17:…

洛谷 P1747 好奇怪的游戏 C语言 bfs

题目&#xff1a; https://www.luogu.com.cn/problem/P1747#submit 题目描述 爱与愁大神坐在公交车上无聊&#xff0c;于是玩起了手机。一款奇怪的游戏进入了爱与愁大神的眼帘&#xff1a;***&#xff08;游戏名被打上了马赛克&#xff09;。这个游戏类似象棋&#xff0c;但…

【c++篇】:解读Set和Map的封装原理--编程中的数据结构优化秘籍

✨感谢您阅读本篇文章&#xff0c;文章内容是个人学习笔记的整理&#xff0c;如果哪里有误的话还请您指正噢✨ ✨ 个人主页&#xff1a;余辉zmh–CSDN博客 ✨ 文章所属专栏&#xff1a;c篇–CSDN博客 文章目录 前言一.set和map的初步封装1.树的节点封装修改2.Find()查找函数3.红…

字符型注入‘)闭合

前言 进行sql注入的时候&#xff0c;不要忘记闭合&#xff0c;先闭合再去获取数据 步骤 判断是字符型注入 用order by获取不了显位&#xff0c;select也一样 是因为它是’)闭合&#xff0c;闭合之后&#xff0c;就可以获取数据了 最后就是一样的步骤