第二个=================
但是我发现,都要反量化,因为pytorch是只能支持浮点数的。
https://blog.csdn.net/lai_cheng/article/details/118961420
Pytorch的量化大致分为三种:模型训练完毕后动态量化、模型训练完毕后静态量化、模型训练中开启量化,本文从一个工程项目(Pose Estimation)给大家介绍模型训练后静态量化的过程。
我又提问了
我要在这个上面进行16比特量化的修改,应该怎么修改?【class SNN(nn.Module):
def __init__(self, tau):
super().__init__()
self.layer = nn.Sequential(
layer.Flatten(),
layer.Linear(28 * 28, 10, bias=False),
neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
)
def forward(self, x: torch.Tensor):
return self.layer(x)】
=
=
=
=
=
=
=
测试【我将模型测试的部分单独写在一个程序中,应该怎么写?】
import torch
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import time
from main import SNN # 确保从你的 main.py 或其他文件中正确导入 SNN 类和 encoder
from torch.utils.tensorboard import SummaryWriter
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer
#python -m main -tau 2.0 -T 50 -device cuda:0 -b 64 -epochs 3 -data-dir \mnist -opt adam -lr 1e-3 -j 2
def test_model(model_path, data_dir, device='cuda:0', T=50, epoch_test=3, tau=2.0):
    """Evaluate a saved SNN checkpoint on the MNIST test split.

    The full test set is evaluated ``epoch_test`` times; because the input
    is Poisson-encoded, every pass sees a different random spike train, so
    the repeated passes show the run-to-run spread of the metrics. Loss and
    accuracy of every pass are printed and logged to TensorBoard.

    Args:
        model_path: Path to a checkpoint whose ``'net'`` entry holds the
            model ``state_dict``.
        data_dir: Root directory of the MNIST dataset (downloaded if missing).
        device: Device string used for the model and the data.
        T: Number of simulation time steps per batch.
        epoch_test: Number of evaluation passes over the test set.
        tau: Membrane time constant passed to ``SNN``; must match the value
            the checkpoint was trained with.
    """
    start_epoch = 0
    out_dir = '.\\out_dir'
    writer = SummaryWriter(out_dir, purge_step=start_epoch)

    # Load the model weights from the checkpoint.
    net = SNN(tau=tau)
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    net.to(device)
    net.eval()

    # Load the test split.
    test_dataset = torchvision.datasets.MNIST(
        root=data_dir,
        train=False,
        transform=ToTensor(),
        download=True,
    )
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    # The encoder does not depend on the epoch, so build it once.
    encoder = encoding.PoissonEncoder()

    try:
        for epoch in range(start_epoch, epoch_test):
            # Per-pass accumulators.
            test_loss = 0
            test_acc = 0
            test_samples = 0
            start_time = time.time()

            with torch.no_grad():
                for img, label in test_loader:
                    img = img.to(device)
                    label = label.to(device)
                    label_onehot = F.one_hot(label, 10).float()

                    # Accumulate output spikes over T steps, then average
                    # to obtain the firing rate of each output neuron.
                    out_fr = 0.
                    for _ in range(T):
                        out_fr += net(encoder(img))
                    out_fr = out_fr / T

                    loss = F.mse_loss(out_fr, label_onehot)
                    test_samples += label.numel()
                    test_loss += loss.item() * label.numel()
                    test_acc += (out_fr.argmax(1) == label).float().sum().item()

                    # Spiking neurons are stateful: clear the membrane
                    # potentials so this batch does not leak into the next.
                    functional.reset_net(net)

            test_time = time.time() - start_time
            test_loss /= test_samples
            test_acc /= test_samples

            writer.add_scalar('test_loss', test_loss, epoch)
            writer.add_scalar('test_acc', test_acc, epoch)
            print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
            print(f'Test completed in {test_time:.2f} seconds.')
    finally:
        # Flush pending TensorBoard events even if evaluation fails.
        writer.close()
if __name__ == '__main__':
    # Checkpoint to evaluate and the dataset root directory.
    model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
    data_dir = 'data'
    test_model(model_path=model_path, data_dir=data_dir)
Test Loss: 0.0167, Test Accuracy: 0.9198
Test completed in 5.56 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9186
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9185
Test completed in 4.77 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9194
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.72 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9193
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9189
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.76 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000
T=5 时的结果如下。T 只影响网络“看见”了多少输入,并不是越长越好,准确率会趋于稳定:
Test Loss: 0.0205, Test Accuracy: 0.9064
Test completed in 2.04 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9050
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9055
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.0203, Test Accuracy: 0.9080
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9074
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9045
Test completed in 1.37 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9058
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9049
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9063
Test completed in 1.47 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9047
Test completed in 1.35 seconds.
test_samples=10000
量化
import torch
# Fixed-point quantization of the exported parameters: every float is stored
# as round(value * QUANT_SCALE). The result is a plain Python int; no clamp
# to an int16/int32 range is applied here.
QUANT_SCALE = 10000

with open('model_params.txt', 'r') as file:
    lines = file.readlines()

with open('model_params_quantized.txt', 'w') as file:
    for line in lines:
        # Each line holds comma-separated float values.
        for val in line.strip().split(','):
            val = val.strip()
            if not val:
                # Blank line or trailing comma: nothing to quantize.
                continue
            quantized_val = int(round(float(val) * QUANT_SCALE))
            # One quantized value per output line.
            file.write(f"{quantized_val}\n")
量化后,再把量化得到的数值写回模型的 checkpoint 中
import torch
# Scale used when model_params_quantized.txt was produced
# (quantized = round(value * QUANT_SCALE)); it must match that script.
QUANT_SCALE = 10000

# Load the original checkpoint_max.pth file.
model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Read the quantized integers, one per line, ignoring blank lines.
with open('model_params_quantized.txt', 'r') as file:
    quantized_values = [int(line.strip()) for line in file if line.strip()]

# The values are consumed sequentially across all tensors, so their count
# must equal the total number of tensor elements in the state dict.
expected_count = sum(
    param.numel()
    for param in checkpoint['net'].values()
    if isinstance(param, torch.Tensor)
)
if len(quantized_values) != expected_count:
    raise ValueError(
        f"model_params_quantized.txt holds {len(quantized_values)} values, "
        f"but the checkpoint tensors need {expected_count}"
    )

# Write the quantized data back into the model parameters.
index = 0
for name, param in list(checkpoint['net'].items()):
    if isinstance(param, torch.Tensor):
        numel = param.numel()
        chunk = quantized_values[index:index + numel]
        # De-quantize: storing the raw integers would make every weight
        # QUANT_SCALE times too large and replace the float tensor with an
        # int64 one. Divide by the scale and restore the original dtype so
        # the model sees the rounded (quantized) weights at their true size.
        dequantized = torch.tensor(chunk, dtype=torch.float64) / QUANT_SCALE
        checkpoint['net'][name] = dequantized.to(param.dtype).view(param.size())
        index += numel

# Save the new checkpoint file.
torch.save(checkpoint, 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_quantized.pth')

# Dump every entry and its shape for a visual sanity check.
model_state_dict = checkpoint['net']
for name, param in model_state_dict.items():
    print(f"{name}: {param}")
    print(f"{name}: {param.size()}")
量化为 int32 之后,准确率下降:
Test Loss: 0.1182, Test Accuracy: 0.6758
Test completed in 2.10 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6765
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6789
Test completed in 1.25 seconds.
test_samples=10000
Test Loss: 0.1180, Test Accuracy: 0.6785
Test completed in 1.30 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6755
Test completed in 1.35 seconds.
test_samples=10000
Test completed in 1.39 seconds.
test_samples=10000
Test Loss: 0.1183, Test Accuracy: 0.6800
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.1185, Test Accuracy: 0.6750
Test completed in 1.38 seconds.
test_samples=10000