网络剪枝——network-slimming 项目复现

news2024/11/28 16:36:51

目录

文章目录

  • 目录
  • 网络剪枝——network-slimming 项目复现
    • clone 存储库
    • Baseline
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Sparsity
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Prune
      • vgg
        • 命令
        • 结果
      • resnet
        • 命令
        • 结果
      • densenet
        • 命令
        • 结果
    • Fine-tune
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • 模型大小计算脚本 param_counter.py
    • 结果汇总
      • CIFAR10

网络剪枝——network-slimming 项目复现

  • 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
  • 【作者复现项目】:
  • 通过百度网盘分享的文件:network-slimming-regin.zip
    链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
    提取码:t1z2
    • 仅以 CIFAR-10 为例,CIFAR-100 同理.
    • 提供中文README_zh-CN.md.
    • 包含 CIFAR-10/100 数据集data.cifar10data.cifar100.
    • 解决了 main.py 运行报错问题.
    • 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.

clone 存储库

注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

  • 链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
    提取码:ud89


Baseline

vgg

训练
  • 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

  • 这个报错通常出现在使用 Python 的multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
  • 为了解决这个问题,应 确保多进程代码(即main.py)位于if __name__ == '__main__':保护块内.
# 导入部分
...

def main():
    ...


if __name__ == '__main__':
    main()
  • 再次运行命令,又报错:

  • 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用.item()方法.
  • 为了解决这个问题,应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval == 0:
    print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
               100. * batch_idx / len(train_loader), loss.item()))

# 在 test 函数中
for data, target in test_loader:
    if args.cuda:
        data, target = data.cuda(), target.cuda()
    data, target = Variable(data), Variable(target)
    output = model(data)
    test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len(test_loader.dataset)
  • 再次运行命令就正常运行了:

结果
  • Terminal

  • 在 ./logs 生成文件checkpoint.pth.tarmodel_best.pth.tar

resnet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果


Sparsity

vgg

训练
  • 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果


Prune

vgg

命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

  • main.py同理,为了解决这个问题,应 确保多进程代码位于if __name__ == '__main__':保护块内
# 导入部分
...

def main():
    ...


if __name__ == '__main__':
    main()
  • 之后就可以正常运行了.

结果
  • Terminal

  • 在./prunes生成文件prune.txtpruned.pth.tar

  • prune.txt中我们可以看到 Number of parametersTest accuracy

resnet

命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet

命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果


Fine-tune

vgg

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果


模型大小计算脚本 param_counter.py

  • 【路径】:./script/param_counter.py
import torch


def load_model(model_path):
    model = torch.load(model_path, map_location=torch.device('cpu'))
    return model


def count_parameters(model_state_dict):
    total_params = sum(p.numel() for p in model_state_dict.values())
    return total_params


def get_model_parameters(model_path):
    # 加载模型状态字典
    model = load_model(model_path)

    # 模型状态字典存储在 'state_dict' 键下
    model_state_dict = model['state_dict'] if 'state_dict' in model else model

    # 计算参数总数
    total_params = count_parameters(model_state_dict)
    return total_params
  • main.py中:
from script.param_counter import get_model_parameters

def main():
    ...
    # 计算 Parameters
    model_path = 'logs/model_best.pth.tar'
    total_params = get_model_parameters(model_path)
    print(f'Total parameters in the model: {total_params}')

结果汇总

注:与原项目结果略有差别.

CIFAR10

CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)
Top1 Accuracy(%)93.7293.6033.9893.75
Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.9995.0094.5995.27
Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.1594.3794.1494.48
Parameters1.09M1.09M0.70M0.72M

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2034689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5个小众宝藏软件看看有没有你喜欢的

冷门APP分享来啦,这5个小众宝藏软件看看有没有你喜欢的吧! 1.space登月计划 从地球到月球的大概距离是3.84亿米,而登月得消耗掉大约3.2亿千卡的能量。一个人想单飞登月得花上万年。 但在space上,可以和小伙伴一起合作玩登月游戏…

记录Java使用websocket

实现场景:每在小程序中添加一条数据时,后台将主动推送一个标记给PC端,PC端接收到标记将进行自动播放音频。 import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import or…

GitHub 2FA中国认证教程

1. 问题描述 在github上有过代码贡献的账号在登录时需要进行2FA双重身份验证。 这是github官方给出的关于2FA的解释: 官方文章地址:点击进入 这是登录时2FA的验证界面: 我们需要使用扩展程序解析这个二维码拿到2FA验证码,填入二维…

python爬虫滑块验证及各种加密函数(基于ddddocr进行的一层封装)

git链接: https://github.com/JOUUUSKA/spider_toolsbox 这里写目录标题 一.识别验证码1、识别英文+数字验证码2、识别滑块验证码3、识别点选验证码 一.识别验证码 git链接: https://github.com/JOUUUSKA/spider_toolsbox 创作不易记得stars 1、识别英文&#xf…

Arduino控制带编码器的直流电机速度

Arduino DC Motor Speed Control with Encoder, Arduino DC Motor Encoder 作者 How to control dc motor with encoder:DC Motor with Encoder Arduino, Circuit Diagram:Driving the Motor with Encoder and Arduino:Control DC motor using Encoder feedback loop: How …

一文读懂Xinstall专属链接推广,轻松解决App运营痛点!

随着互联网的飞速发展,App推广和运营面临着前所未有的挑战。传统的营销方式已经难以适应多变的市场环境,而Xinstall专属链接推广应运而生,成为解决App获客难题的新利器。本文将深入探讨Xinstall专属链接推广如何帮助推广者触达更多用户&#…

MacOS vue-cli为2.9.6 无法升级的解决方案

背景 今天需要验证plop工具做前端工程化实践,打算使用vue3方式,结果发现vue-cli 2.9.6一直无法升级成功,也无法通过vue-cli生成vue3模板工程,测试了几把后,最终升级vue-cli成功,为了能给出现同样问题的小伙…

上瘾模型与产品激励系统

​产品要增加客户粘性,使产品深入人心就需要让用户对产品上瘾。如何使用户对产品上瘾?对于产品来说,就需要建立产品的激励系统。 产品的激励系统要做的事就是对用户进行激励,就是让用户主动完成产品或服务想要他们做的事情。 那么…

重启人生计划-勇敢者先行

🥳🥳🥳 茫茫人海千千万万,感谢这一刻你看到了我的文章,感谢观赏,大家好呀,我是最爱吃鱼罐头,大家可以叫鱼罐头呦~🥳🥳🥳 如果你觉得这个【重启人生…

分布式与微服务详解

1. 单机架构 只有一台机器,这个机器负责所有的工作 (这里假定一个电商网站) 现在大部分公司的产品都是单机架构 。 2. 分布式架构 一台机器的硬件资源是有限的,服务器处理请求是需要占用硬件资源的,如果业务增长&a…

前端学习笔记-JS篇-01

JS基础Day1-01-必看-基本软件以及准备工作_哔哩哔哩_bilibili JavaScript介绍 是什么 1.JavaScript (是什么?) 是一种运行在客户端(浏览器)的编程语言,实现人机交互效果2.作用(做什么?) 网页特效(监听用户的一些行为让网页作出对应的反馈)表单验证(针对表单…

streampark-使用记录-备忘

1、重新部署的任务会读历史配置(包括错误配置),即使点击确认了也无效 解决:复制新的任务,修改ckeckpoint 路径(重要) 2、任务启动报错,即使后续把脚本改正确或者复制其他脚本过来执…

什么是 Java?

探索 Java,一种多功能且功能强大的编程语言。释放其构建强大应用程序的潜力。 前言 简单来说,Java 是一种用于开发软件应用程序的面向对象设计的编程语言。截至 2019 年,它是世界上最受欢迎的编程语言,尤其是因为它是开源的&#…

MySQL 的 InnoDB 缓冲池里有什么?--InnoDB存储梳理(二)

文章目录 缓冲池的配置介绍一张表 INNODB_BUFFER_POOL_PAGES字段解释 缓冲池的配置 以下配置的意思,缓冲池在内存中的大小为20M;只有1个缓冲池实例;每一块的大小,插入缓冲占的百分比 # InnoDB 缓存池配置 innodb_buffer_pool_si…

Spring Boot 3.x Web单元测试最佳实践

上一篇:Spring Boot 3.x Rest API统一异常处理最佳实践 下一篇:Spring Boot 3.x Filter实战:记录请求日志 Spring Boot为我们提供了非常便捷的web层Rest API单元测试的API,这种开发能力也是小伙伴必须要掌握的。如何对数据库、中…

【简历】扬州某一本大学:前端秋招简历指导,面试通过率低

注:为保证用户信息安全,姓名和学校等信息已经进行同层次变更,内容部分细节也进行了部分隐藏 简历说明 这是25届一本前端同学的简历。这是一个老牌一本学校,老牌一本定位求职层次,可以从传统的中厂上升到大厂。学历可以…

Pytorch离线文件的快速下载

一、为什么要使用离线方式安装Pytorch 参考我的博客《直接用文件方式安装Cuda版本的Pytorch》可以方便的安装Cuda版本的Pytorch,比较方便快捷。系统重装后,可以快速的重新搭建系统。 二、如何直接下载Pytorch的离线安装文件whl 可以参考这个博客&#…

基于SpringBoot的桂林二手房交易系统的设计与实现---附源码17680

目录 1 绪论 1.1 选题背景与意义 1.2国内外研究现状 1.3论文结构与章节安排 2系统分析 2.1 可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分析 2.4 系统流程分析 2.4.1系统开发流程 2.4.2 用户登录流程 2.4.3 系统操作流程 2.4…

java数字产科管理系统源码,产科业务信息系统源码,产科电子病历系统源码,前端框架:Vue、ElementUI 数 据 库:MySQL8.0.36

数字产科管理系统源码,产科业务信息系统源码,产科电子病历系统源码 数字产科管理系统是一套针对孕产妇的基于流程管控的产科业务信息系统。该系统由门诊系统、住院系统、数据统计模块三部分组成。实现孕产妇围产期一待产一住院的持续化、专业化、电子化…

高性能并行计算面试-核心概念-问题理解

目录 1.什么是并行计算?高性能从哪些方面体现? 2.CPU常见的并行技术 3.GPU并行 4.并发与并行 5.常见的并行计算模型 6.如何评估并行程序的性能? 7.描述Am达尔定律和Gustafson定律,并解释它们对并行计算性能的影响 8.并行计…