spikingjelly训练自己的网络---量化 --测试

news2024/12/25 9:09:01

在这里插入图片描述
在这里插入图片描述
第二个=================

在这里插入图片描述
在这里插入图片描述
但是我发现,都要反量化,因为pytorch是只能支持浮点数的。

在这里插入图片描述

https://blog.csdn.net/lai_cheng/article/details/118961420
Pytorch的量化大致分为三种:模型训练完毕后动态量化、模型训练完毕后静态量化、模型训练中开启量化,本文从一个工程项目(Pose Estimation)给大家介绍模型训练后静态量化的过程。

我又提问了
我要在这个上面进行16比特量化的修改,应该怎么修改?【class SNN(nn.Module):
def init(self, tau):
super().init()

    self.layer = nn.Sequential(
        layer.Flatten(),
        layer.Linear(28 * 28, 10, bias=False),
        neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

def forward(self, x: torch.Tensor):
    return self.layer(x)】

在这里插入图片描述
在这里插入图片描述

=

=

=

=

=

=

=
测试【我将模型测试的部分单独写在一个程序中,应该怎么写?】
在这里插入图片描述

import torch
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import time
from main import SNN  # 确保从你的 main.py 或其他文件中正确导入 SNN 类和 encoder
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

#python -m main -tau 2.0 -T 50 -device cuda:0 -b 64 -epochs 3 -data-dir \mnist -opt adam -lr 1e-3 -j 2

def test_model(model_path, data_dir, device='cuda:0', T=50,epoch_test = 3):
    start_epoch = 0
    out_dir = '.\\out_dir'
    writer = SummaryWriter(out_dir, purge_step=start_epoch)

    # 加载模型
    net = SNN(tau=2.0)  # 使用适当的参数初始化你的模型
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    net.to(device)
    net.eval()

    # 加载测试数据集
    test_dataset = torchvision.datasets.MNIST(
        root=data_dir,
        train=False,
        transform=ToTensor(),
        download=True
    )
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    for epoch in range(start_epoch, epoch_test):  
        # 初始化性能指标
        test_loss = 0
        test_acc = 0
        test_samples = 0
        start_time = time.time()

        encoder = encoding.PoissonEncoder()

        with torch.no_grad():
            for img, label in test_loader:
                img = img.to(device)
                label = label.to(device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(T):
                    encoded_img = encoder(img)  # 确保 encoder 已经定义
                    out_fr += net(encoded_img)
                out_fr = out_fr / T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                # 注意:如果你的网络需要在每次迭代后重置状态,请在这里调用重置函数

        test_time = time.time() - start_time
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
        print(f'Test completed in {test_time:.2f} seconds.')

if __name__ == '__main__':
    model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'  # 模型路径
    data_dir = 'data'  # 数据集路径
    test_model(model_path, data_dir)
Test Loss: 0.0167, Test Accuracy: 0.9198
Test completed in 5.56 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9186
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9185
Test completed in 4.77 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9194
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.72 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9193
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9189
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.76 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000

T=5时候,结果如下T只是影响网络看见 了什么,越长不一定越好,趋于稳定

Test Loss: 0.0205, Test Accuracy: 0.9064
Test completed in 2.04 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9050
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9055
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.0203, Test Accuracy: 0.9080
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9074
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9045
Test completed in 1.37 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9058
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9049
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9063
Test completed in 1.47 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9047
Test completed in 1.35 seconds.
test_samples=10000
量化
import torch

with open('model_params.txt', 'r') as file:
    lines = file.readlines()

with open('model_params_quantized.txt', 'w') as file:
    for line in lines:
        # 去除换行符并按逗号和空格拆分字符串
        values = line.strip().split(',')
        for val in values:
            float_val = float(val.strip())
            quantized_val = int(round(float_val * 10000))  # 量化为int32
            file.write(f"{quantized_val}\n")


量化后再把数字写入进去
import torch

# 加载原始的checkpoint_max.pth文件
model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# 读取量化后的数据
with open('model_params_quantized.txt', 'r') as file:
    quantized_values = [int(line.strip()) for line in file.readlines()]

# 将量化后的数据写回到模型参数中
index = 0
for name, param in checkpoint['net'].items():
    if isinstance(param, torch.Tensor):
        numel = param.numel()
        quantized_param = torch.tensor(quantized_values[index:index+numel]).view(param.size())
        checkpoint['net'][name] = quantized_param
        index += numel

# 保存新的checkpoint文件
torch.save(checkpoint, 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_quantized.pth')


model_state_dict = checkpoint['net']
for name, param in model_state_dict.items():
    print(f"{name}: {param}")
    print(f"{name}: {param.size()}")

量化为int32之后的准确率  下降
Test Loss: 0.1182, Test Accuracy: 0.6758
Test completed in 2.10 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6765
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6789
Test completed in 1.25 seconds.
test_samples=10000
Test Loss: 0.1180, Test Accuracy: 0.6785
Test completed in 1.30 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6755
Test completed in 1.35 seconds.
test_samples=10000
Test completed in 1.39 seconds.
test_samples=10000
Test Loss: 0.1183, Test Accuracy: 0.6800
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.1185, Test Accuracy: 0.6750
Test completed in 1.38 seconds.
test_samples=10000

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1580979.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 解决 Process 执行命令行命令报【CreateProcess error=2, 系统找不到指定的文件。】错误问题

目录 问题 问题代码 解决方案 判断操作系统 问题 使用 Process 执行命令行命令时,报 CreateProcess error2, 系统找不到指定的文件。但明明指定的文件是存在的。而且这种错误只在 IDEA 中运行会报错,打包后直接 java -jar 运行就能正常运行&#xf…

国产DSP FT-M6678开发-中断开发

全局中断控制器(CIC) FT-M6678 芯片集成了众多的外设,这些外设都可产生中断事件源,这些中断事件如何被服务取决于用户的特殊应用。在FT-M6678 芯片中,EDMA 和CorePac 都能够为事件服务,为了最大限度的增加系…

vue3第十六节(keep-alive 内置组件)

keep-alive 1、目的 在使用组件时,有时我们需要将组件进行缓存,而不是重新渲染,用以提高性能,避免重复加载DOM,提升用户的体验; keep-alive 组件可以做到这一点,它允许你缓存组件实例&#xf…

家用洗地机哪个牌子好?四大热销机型推荐,值得推荐!

随着科技的进步,洗地机在日常生活中能够帮助人们省时省力地打扫卫生,但市面上出现了各种各样的洗地机,好坏参差不齐,选择一个好品牌的洗地机非常重要,因为它们有着可靠的质量保证。那市面上如此众多的洗地机品牌&#…

Python爬虫之分布式爬虫

分布式爬虫 1.详情介绍 分布式爬虫是指将一个爬虫任务分解成多个子任务,在多个机器上同时执行,从而加快数据的抓取速度和提高系统的可靠性和容错性的技术。 传统的爬虫是在单台机器上运行,一次只能处理一个URL,而分布式爬虫通过将…

关于阿里云centos系统下宝塔面板部署django/中pip install mysqlclient失败问题的大总结/阿里云使用oss长期访问凭证

python版本3.12.0 问题1 解决方案 sudo vim /etc/profile export MYSQLCLIENT_CFLAGS"-I/usr/include/mysql" export MYSQLCLIENT_LDFLAGS"-L/usr/lib64/mysql" Esc退出编辑模式 :wq退出并且保存 问题二 说是找不到 mysql.h头文件 CentOS ‘…

【Python】不会优雅的记日志,你又又Out了!!!

1. 引言 在日常开发中,大家经常使用 print 函数来调试我们写的的代码。然而,随着打印语句数量的增加,由于缺乏行号或函数名称,很难确定输出来自何处。而且随着print语句的增多,调试完代码删除这些信息的时候也比较麻烦…

【动态规划-线性dp】【蓝桥杯备考训练】:乌龟棋、最长上升子序列、最长公共子序列、松散子序列、最大上升子序列和【已更新完成】

目录 1、乌龟棋 2、最长上升子序列 3、最长公共子序列 4、松散子序列 5、最大上升子序列和 1、乌龟棋 小明过生日的时候,爸爸送给他一副乌龟棋当作礼物。 乌龟棋的棋盘只有一行,该行有 N 个格子,每个格子上一个分数(非负整…

植物大战僵尸Python版,附带源码注解

目录 一、实现功能 二、安装环境要求 三、如何开始游戏 四、怎么玩 五、演示 六、部分源码注释 6.1main.py 6.2map.py 6.3Menubar.py 七、自定义 7.1plant.json 7.2zombie.json 一、实现功能 实施植物:向日葵、豌豆射手、壁桃、雪豆射手、樱桃炸弹、三…

笔记-Building Apps with the ABAP RESTful Application Programming Model-Week3

Week3 Unit 1: The Enhanced Business Scenario 本节介绍了将要练习的demo的业务场景,在前两周成果的基础上,也就是只读列表,也可以说是报表APP基础上启用了事务能力,也就是CURD以及自定义业务功能的能力,从创建基本的behavior definition,然后behavior definition proj…

老王讲IT:高级变量类型

IT老王:高级变量类型 目标 列表 元组 字典 字符串 公共方法 变量高级 知识点回顾 Python 中数据类型可以分为 数字型 和 非数字型 数字型 整型 (int) 浮点型(float) 布尔型(bool) 真 True 非 0 数 —— 非零…

阻抗匹配(低频和高频)

一、当信号为低频时 二、当信号为高频时 三、最理想的阻抗要求? 四、为什么射频阻抗基本都是50欧姆(信号源阻抗传输线特征阻抗负载阻抗50欧姆) 综合考虑,射频行业标准选定50欧姆阻抗。

【鸿蒙开发】系统组件Text,Span

Text组件 Text显示一段文本 接口: Text(content?: string | Resource) 参数: 参数名 参数类型 必填 参数描述 content string | Resource 否 文本内容。包含子组件Span时不生效,显示Span内容,并且此时text组件的样式不…

算法四十天-删除排序链表中的重复元素

删除排序链表中的重复元素 题目要求 解题思路 一次遍历 由于给定的链表是排好序的,因此重复的元素在链表中的出现的位置是连续的,因此我们只需要对链表进行一次遍历,就可以删除重复的元素。 具体地,我们从指针cur指向链表的头节…

React路由快速入门:Class组件和函数式组件的使用

1. 介绍 在开始学习React路由之前,先了解一下什么是React路由。React Router是一个为React应用程序提供声明式路由的库。它可以帮助您在应用程序中管理不同的URL,并在这些URL上呈现相应的组件。 2. 安装 要在React应用程序中使用React路由,…

【前沿模型解析】潜在扩散模型 2-3 | 手撕感知图像压缩 基础块 自注意力块

1 注意力机制回顾 同ResNet一样,注意力机制应该也是神经网络最重要的一部分了。 想象一下你在观看一场电影,但你的朋友在给你发短信。虽然你正在专心观看电影,但当你听到手机响起时,你会停下来查看短信,然后这时候电…

C++可变模板参数与包装器(function、bind)

文章目录 一、 可变参数模板1. 概念2. 参数包值的获取 二、 包装器1. 什么是包装器2. 为什么要有包装器3. std::function(1) function概念(2) 利用function解决实例化多份问题(3) 包装器的其他使用场景&…

水资源管理系统:守护生命之源,构建和谐水生态

水资源是维系地球生态平衡和人类社会可持续发展的重要基础。然而,随着人口增长、工业化和城市化的加速,水资源短缺、水质污染和生态破坏等问题日益凸显。在这样的背景下,构建一个全面、高效、智能的水资源管理系统显得尤为迫切和必要。 项目…

Google Cookie意见征求底部弹窗

关于欧盟 Cookie 通知 根据2024年欧盟的《通用数据保护条例》以及其他相关法规,要求google cookie的使用必须征求用户的同意,才能进行收集用户数据信息,因此跨境独立站,如果做欧洲市场,就必须弹出cookie收集数据弹窗&a…

阿赵UE学习笔记——26、动画混合空间

阿赵UE学习笔记目录   大家好,我是阿赵。   继续学习虚幻引擎的使用。之前学习了通过蓝图直接控制动画播放,或者通过动画状态机去控制播放。这次来学习一种比较细致的动画控制播放方式,叫做动画混合空间。 一、使用的情景 假设我们现在需…