VGGNet剪枝实战:使用VGGNet训练、稀疏训练、剪枝、微调等,剪枝出只有3M的模型(二)

news2024/9/20 5:42:33

文章目录

  • 稀疏训练VGGNet
  • 剪枝
    • 导入库文件
    • 测试函数
    • 定义全局参数
    • BN通道排序
    • 制作Mask
    • 剪枝操作
  • 微调
    • 微调方法
    • 微调结果

稀疏训练VGGNet

新建train_sp.py脚本。稀疏化训练的过程和正常训练类似,不同的是在BN层中各权重加入稀疏因子,代码如下:

def updateBN(model,s=0.0001,epoch=1,epochs=1000):
    srtmp = s * (1 - 0.9 * epoch / epochs)
    for m in model.modules():
        if isinstance(m,nn.BatchNorm2d):
            m.weight.grad.data.add_(srtmp*torch.sign(m.weight.data))
            m.bias.grad.data.add_(s * 10 * torch.sign(m.bias.data))

加入到train函数中,如图:
在这里插入图片描述
s的设置需要根据数据集调整,可以通过观察tensorboard的map,gamma变化直方图等选择。我在本次训练种使用的是0.001.

训练完成后,就可以使用tensorboard观察训练结果,在根目录运行:

tensorboard --logdir .

然后看到如下信息:

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)

在浏览器中打开http://localhost:6006/就能看到。
在这里插入图片描述
在这里插入图片描述
蓝色的是正常训练,BN权重的分布情况。紫红色的是加入稀疏因子后BN权重的分布情况。
稀疏化训练结果:
在这里插入图片描述
在这里插入图片描述
结果基本上和正常训练一致!最终结果也是95.6%。

剪枝

新建prune.py脚本,这个脚本是剪枝脚本。

导入库文件

import os
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

from vgg import VGG
import numpy as np

测试函数

测试函数,用来测试剪枝后的模型ACC,代码如下:

# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test():
    # 读取数据
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48214436, 0.42969334, 0.33318862], std=[0.2642221, 0.23746745, 0.21696019])
    ])
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)

    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)
    model.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

定义全局参数

if __name__ == '__main__':
    BATCH_SIZE=16
    percent=0.7
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_path='checkpoints/vgg_sp/best.pth'
    save_name='pruned.pth'

BATCH_SIZE:测试函数的BatchSize。
percent:剪枝的比率。
DEVICE :如果有显卡则使用GPU,没有则使用cpu。
model_path:稀疏训练模型的路径。
save_name:剪枝后,模型的路径。

BN通道排序

    #加载稀疏训练的模型
    model = torch.load(model_path)
    print(model)
    total = 0 # 统计所有BN层的参数量
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0] # 每个BN层权重w参数量
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        #将各个BN层的参数值复制到bn中
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index + size)] = m.weight.data.abs().clone()
            index += size
    #对bn中的weight值排序
    y, i = torch.sort(bn)#
    thre_index = int(total * percent)
    thre = y[thre_index]#取bn排序后的第thresh_index索引值为bn权重的截断阈值

制作Mask

    pruned = 0 #统计BN层剪枝通道数
    cfg = []#统计保存通道数
    cfg_mask = []#BN层权重矩阵,剪枝的通道记为0,未剪枝通道记为1
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.clone()
            mask = weight_copy.abs().gt(thre).float().cuda()#阈值分离权重
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)#更新BN层的权重,剪枝通道的权重值为0
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))#记录未被剪枝的通道数量
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                  format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')

    pruned_ratio = pruned / total
    print('Pre-processing Successful!')
    test()
    # Make real prune
    print(cfg)

剪枝操作

    newmodel = VGG(cfg=cfg,num_classes=12)
    newmodel.cuda()
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
            w = m0.weight.data[:, idx0, :, :].clone()
            w = w[idx1, :, :, :].clone()
            m1.weight.data = w.clone()
            # m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[:, idx0].clone()
    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, save_name)
    print(newmodel)
    model = newmodel
    test()

剪枝后保存模型

微调

微调方法

微调方法和正常训练类似,加载剪枝后的模型和配置,然后训练、验证即可!

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'checkpoints/vgg_pruned'
    if os.path.exists(file_dir):
        print('true')
        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)
    # 设置全局参数
    model_lr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 300
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    classes = 12
    resume = 'pruned.pth'

    # 设置模型
    model = torch.load(resume)
    model_ft=VGG(cfg=model['cfg'],num_classes=classes)
    model_ft.load_state_dict(model['state_dict'])
    model_ft.to(DEVICE)
    print(model_ft)

微调结果

Val set: Average loss: 0.2845, Accuracy: 457/482 (95%)

                           precision    recall  f1-score   support

              Black-grass       0.79      0.86      0.83        36
                 Charlock       1.00      1.00      1.00        42
                 Cleavers       1.00      0.96      0.98        50
         Common Chickweed       0.94      0.91      0.93        34
             Common wheat       0.93      1.00      0.97        42
                  Fat Hen       0.97      0.97      0.97        34
         Loose Silky-bent       0.88      0.78      0.83        46
                    Maize       0.96      1.00      0.98        45
        Scentless Mayweed       0.93      0.96      0.95        45
          Shepherds Purse       0.97      0.97      0.97        35
Small-flowered Cranesbill       1.00      0.97      0.99        36
               Sugar beet       1.00      1.00      1.00        37

                 accuracy                           0.95       482
                macro avg       0.95      0.95      0.95       482
             weighted avg       0.95      0.95      0.95       482

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/856491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

通达OA前台getshell

你永远要记住&#xff0c;你跟别人不一样&#xff0c;你是一个有另外世界的人 漏洞复现 访问网站url&#xff1a; 将exp中的target替换为目标url Payload替换为自己的木马即可 ​关键exp如下&#xff1a; 执行exp漏洞利用&#xff0c;中间过程会询问是否删除文件&#x…

Mysql load data隐藏字符特殊字符ESC的方法

mysqlimport --userroot --passwordroot1234 --default-character-setutf8 --fields-terminated-by"$(echo -ne \033)" --lines-terminated-by\n --local shenl /root/xab.dat https://jisuan5.com/ascii/ ascii码值 SHOW GLOBAL VARIABLES LIKE local_infile; LO…

微服务Ribbon-负载均衡原理

目录 一、LoadBalancerIntercepor 二、LoadBalancerClient 三、负载均衡策略IRule 四、总结 上一篇中&#xff0c;我们添加了LoadBalanced注解&#xff0c;即可实现负载均衡功能&#xff0c;这是什么原理呢&#xff1f; SpringCloud底层其实是利用了一个名为Ribbon的组件&…

SpringBoot学习——springboot整合email springboot整合阿里云短信服务

目录 引出springboot整合email配置邮箱导入依赖application.yml配置email业务类测试类 springboot整合阿里云短信服务申请阿里云短信服务测试短信服务获取阿里云的accessKeyspringboot整合阿里云短信导包工具类 总结 引出 1.springboot整合email&#xff0c;qq邮箱&#xff0c;…

计算机网络实验3:双绞线跳线的制作和测试

文章目录 1. 主要教学内容2. 双绞线跳线的制作和测试 1. 主要教学内容 实验内容&#xff1a;掌握双绞线制作过程中的剥线、理线、插线、压线以及测线。所需学时&#xff1a;2。重难点&#xff1a;双绞线的类别及其用途周次&#xff1a;第2周。教材相关章节&#xff1a;第5章&a…

Android Studio实现Spinner下拉列表

效果图 点击下拉列表 点击某一个下拉列表 MainActivity package com.example.spinneradapterpro;import androidx.appcompat.app.AppCompatActivity;import android.os.Bundle; import android.view.View; import android.widget.AdapterView; import android.widget.Spinn…

怎样的公司称为集团?

集团是什么&#xff1f; 集团的意思是有目的组织起来共同行动的团体。企业集团不具有独立的法人资格。《公司法》中并没有“集团”一说。只有有限责任公司和股份有限公司的提法。有的公司进行多元化经营战略&#xff0c;在多个领域均成立了相应的子公司&#xff0c;这样&#…

【递归算法实践】验证二叉搜索树

目录 1. 递归算法 2. 递归实现验证二叉搜索树 3. 递归解法的实现逻辑 4. 递归实现的实例分析 1. 递归算法 递归是一种通过函数自身调用来解决问题的算法&#xff0c;它可以使代码更加简洁和优雅&#xff0c;同时也能够解决许多复杂的问题。在递归中&#xff0c;函数会不断…

【Windows10下启动RocketMQ报错:找不到或无法加载主类 Files\Java\jdk1.8.0_301\lib\dt.jar】解决方法

Windows10下启动RocketMQ报错&#xff1a;找不到或无法加载主类 一、问题产生二、产生原因三、解决办法 一、问题产生 参考RocketMQ Github官网上的说明&#xff0c;下载rocketmq-all-5.1.3-bin-release.zip&#xff0c;解压配置环境变量后&#xff0c;执行如下命令&#xff1a…

基于2.4G RF开发的无线游戏手柄解决方案

平时喜欢玩游戏的朋友&#xff0c;肯定知道键鼠在某些类型的游戏适配和操作方面&#xff0c;不如手柄。作为一个游戏爱好者&#xff0c;还得配上一个游戏手柄才行。比如动作和格斗、体育游戏&#xff0c;由于手柄更合理的摇杆位置和按键布局&#xff0c;操作起来也是得心应手。…

6.pip简介,第三方库的安装

引言 使用过Visual Studio的小伙伴可能对npm不陌生&#xff0c;没错&#xff0c;pip与npm的功能是一样的。 首先要知道&#xff0c;Python这门语言拥有着丰富的标准库以及先辈们开发的各种功能强大的第三方库。而今天我们主要学习的呢就是关于Python中的包管理工具。它是Pytho…

图像的镜像变换之c++实现(qt + 不调包)

1.基本原理 1.水平镜像变化 设图像的宽度为width&#xff0c;则水平镜像变化的映射关系如下&#xff1a; 2.垂直镜像变化 设图像的宽度为height&#xff0c;则垂直镜像变化的映射关系如下&#xff1a; 2.代码实现&#xff08;代码是我以前自学图像处理时写的&#xff0c;代码很…

VS Code安装使用教程

目录 1. VS Code是什么&#xff1f; 2. VS Code的下载和安装 下载&#xff1a; 安装&#xff1a; 2.2 环境的介绍 3. VS Code配置C/C开发环境 3.1 下载和配置MinGW-w64编译器套件 下载&#xff1a; 配置&#xff1a; 3.2 安装C/C插件 3.3 重启VSCode 4. 在VSCode上编…

如何快速创建学校录取查询系统?

作为一名老师&#xff0c;我深知学生及家长们对于录取情况的关注和期待。因此&#xff0c;学校公布录取情况表是非常重要的一项工作。在这篇文章中&#xff0c;我将分享学校公布录取情况表的步骤和流程&#xff0c;帮助大家更好地了解录取情况。 老教师一般是使用易查分来让家…

坤简炫酷的JQuery轮播图插件

介绍&#xff1a; 找到了一个炫酷的JQuery轮播图插件&#xff0c;只需要配置三四行代码就可以实现很多二维三维炫酷的切换效果。 视频效果及教程&#xff1a; https://www.bilibili.com/video/BV1Fu4y1d776/ 代码&#xff1a; https://github.com/w-x-x-w/AwesomeWeb 使用…

[软件工具][原创]OCR识字找图关键词找图以文搜图工具使用教程

OCR识字找图工具功能简介&#xff1a; 当你有一批图片但是想提取图片里面包含关键词的的图片&#xff0c;以前都是手工肉眼打开去找&#xff0c;其实这个大可不必&#xff0c;现在只需输入关键词&#xff0c;软件会自动搜索所有图片&#xff0c;只要包含指定关键词就会复制或者…

【雕爷学编程】Arduino动手做(199)---8x32位WS2812B全彩屏模块

37款传感器与模块的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&#x…

【12】Git工具 协同工作平台使用教程 【Gitee】【腾讯工蜂】

tips&#xff1a;少量的git安装和使用教程&#xff0c;更多讲快速使用上手Gitee和工蜂平台 一、准备工作 1、下载git Git - Downloads (git-scm.com) 找到对应操作系统&#xff0c;对应版本&#xff0c;对应的位数 下载后根据需求自己安装&#xff0c;然后用git --version验…

从R调用python并将即时输出到cons

我使用R命令从R运行python脚本&#xff1a; system(python test.py)但我的打印报表测试.py在python程序完成之前不要出现在R控制台中。我想查看print语句&#xff0c;因为python程序正在R中运行。我也尝试过sys.stdout.write()&#xff0c;但结果是一样的。非常感谢您的帮助。…

Linux系统中的自旋锁(两幅图清晰说明)

总结&#xff1a; 多CPU下的自旋锁采取的是忙等待&#xff08;原地打转&#xff09;机制&#xff0c;虽然忙等待的线程占用了它所在的cpu&#xff0c;但其他线程仍可放到其他CPU上执行。所以自旋锁上锁和解锁之间的临界区代码要尽量的短&#xff0c;最好不要超过5行&#xff0c…