SpikingJelly笔记之梯度替代

news2025/1/8 5:48:41

文章目录

  • 前言
  • 一、梯度替代
  • 二、网络结构
  • 三、MNIST分类
    • 1、单步模式
    • 2、多步模式
  • 总结


前言

在SpikingJelly使用梯度替代训练SNN,构建单层全连接SNN实现MNIST分类任务。


一、梯度替代

1、梯度替代:

阶跃函数不可微,无法进行反向传播

g ( x ) = { 1 , x ≥ 0 0 , x < 0 g(x) = \left\{\begin{matrix} 1,&\quad x\ge 0\\ 0,&\quad x<0\\ \end{matrix}\right. g(x)={1,0,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { + ∞ , x = 0 0 , x ≠ 0 g^{\prime}(x) = \left\{\begin{matrix} +∞&,\quad x= 0\\ 0&,\quad x\neq0\\ \end{matrix}\right. g(x)={+0,x=0,x=0

前向传播使用阶跃函数,反向传播使用替代函数

2、梯度替代函数:

来源:spikingjelly.activation_based.surrogate package

①Sigmoid:surrogate.Sigmoid(alpha=4.0, spiking=True)

g ( x ) = s i g m o i d ( α x ) = 1 1 + e − α x g(x) = sigmoid(\alpha x)=\frac{1}{1+e^{-\alpha x}} g(x)=sigmoid(αx)=1+eαx1

g ′ ( x ) = α ∗ s i g m o i d ( α x ) ∗ ( 1 − s i g m o i d ( α x ) ) g^{\prime}(x) = \alpha*sigmoid(\alpha x)*(1-sigmoid(\alpha x)) g(x)=αsigmoid(αx)(1sigmoid(αx))

②ATan:surrogate.ATan(alpha=2.0, spiking=True)

g ( x ) = 1 π a r c t a n ( π 2 α x ) + 1 2 g(x) = \frac{1}{\pi}arctan(\frac{\pi}{2}\alpha x)+\frac{1}{2} g(x)=π1arctan(2παx)+21

g ′ ( x ) = α 2 ( 1 + ( π 2 α x ) 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+(\frac{\pi}{2}\alpha x)^2)} g(x)=2(1+(2παx)2)α

③SoftSign:surrogate.SoftSign(alpha=2.0, spiking=True)

g ( x ) = 1 2 ( α x 1 + ∣ α x ∣ + 1 ) g(x) = \frac{1}{2}(\frac{\alpha x}{1+|\alpha x|}+1) g(x)=21(1+αxαx+1)

g ′ ( x ) = α 2 ( 1 + ∣ α x ∣ 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+|\alpha x|^2)} g(x)=2(1+αx2)α

④LeakyKReLU:surrogate.LeakyKReLU(spiking=True, leak: float=0.0, k: float=1.0)

g ( x ) = { k ∗ x , x ≥ 0 l e a k ∗ x , x < 0 g(x) = \left\{\begin{matrix} k*x,&\quad x\ge 0\\ leak*x,&\quad x<0\\ \end{matrix}\right. g(x)={kx,leakx,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { k , x ≥ 0 l e a k , x < 0 g^{\prime}(x) = \left\{\begin{matrix} k&,\quad x\ge 0\\ leak&,\quad x<0\\ \end{matrix}\right. g(x)={kleak,x0,x<0

二、网络结构

使用神经元层替代激活函数

1、ANN

nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False),
    nn.Softmax()
    )

2、SNN

nn.Sequential(
    layer.Flatten(),
    layer.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
    )

三、MNIST分类

1、单步模式

(1)导入库

import time
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from spikingjelly.activation_based import neuron, encoding,\
    functional, surrogate, layer, monitor
from spikingjelly import visualizing
from load_mnist import load_mnist

(2)构建数据加载器

将numpy数据封装成DataLoader

使用Pytorch自带的数据集会更方便

def To_loader(x_train, y_train, x_test, y_test, batch_size):
    # 转为张量
    x_train = torch.from_numpy(x_train.astype(np.float32))
    y_train = torch.from_numpy(y_train.astype(np.float32))
    x_test = torch.from_numpy(x_test.astype(np.float32))
    y_test = torch.from_numpy(y_test.astype(np.float32))
    # 数据集封装
    train_dataset = TensorDataset(x_train, y_train)
    test_dataset = TensorDataset(x_test, y_test)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    return train_dataset, test_dataset, train_loader, test_loader

(3)构建SNN模型

将LIF神经元层当作激活函数使用

使用ATan作为梯度替代函数进行反向传播

class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            layer.Linear(784, 10, bias=False),
            neuron.LIFNode(tau=2.0,
                           decay_input=True,
                           v_threshold=1.0,
                           v_reset=0.0,
                           surrogate_function=surrogate.ATan(),
                           step_mode='s',
                           store_v_seq=False)
            )
    def forward(self, x):
        return self.layer(x)

(4)训练参数

使用泊松编码器对输入进行编码

取10000个样本进行训练

epoch_num = 10
batch_size = 256
T = 50
lr = 0.001
encoder = encoding.PoissonEncoder() # 泊松编码器
model = SNN() # 单层SNN
loss_function = nn.MSELoss() # 均方误差
optimizer = optim.Adam(model.parameters(), lr) # Adam优化器
x_train, y_train, x_test, y_test = \
    load_mnist(normalize=True, flatten=False, one_hot_label=True)
train_dataset, test_dataset, train_loader, test_loader =\
    To_loader(x_train[:10000], y_train[:10000], x_test, y_test, batch_size)

(5)迭代训练

①取一段时间的平均发放率作为输出

②损失函数采用交叉熵或均方差,使对应神经元fout→1,其他神经元fout→0

③每批训练后重置网络状态

④每轮训练后测试准确率

start_time = time.time()
loss_train_list = []
acc_train_list = []
acc_test_list = []
for epoch in range(epoch_num):
    print('Epoch:%s'%(epoch+1))
    # 模型训练
    loss_train = 0
    acc_train = 0
    for x, y in train_loader:
        f_out = torch.zeros((y.shape[0], 10)) # 输出频率
        # 前向计算,逐步传播
        for t in range(T):
            encoded_x = encoder(x.reshape(-1, 784))
            f_out += model(encoded_x)
        f_out /= T
        # 反向传播
        loss = loss_function(f_out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 计算损失值与准确率
        loss_train += loss.item()
        acc_train += (f_out.argmax(1) == y.argmax(1)).sum().item()
        # 清除状态
        functional.reset_net(model)
    acc_train /= len(train_dataset)
    loss_train_list.append(loss_train)
    acc_train_list.append(acc_train)
    print('loss_train:', loss_train)
    print('acc_train:{:.2%}:'.format(acc_train))
    # 模型测试
    with torch.no_grad():
        acc_test = 0
        for x, y in test_loader:
            f_out = torch.zeros((y.shape[0], 10))
            # 逐步传播
            for t in range(T):
                encoded_x = encoder(x.reshape(-1,784))
                f_out += model(encoded_x)
            f_out /= T
            loss = loss_function(f_out, y)
            acc_test += (f_out.argmax(1) == y.argmax(1)).sum().item()
            functional.reset_net(model)
        acc_test /= len(test_dataset)
        acc_test_list.append(acc_test)
        print('acc_test:{:.2%}'.format(acc_test))
end_time = time.time()
print('Time:{:.1f}s'.format(end_time - start_time))

训练结果:

Epoch:10
loss_train: 0.8223596904426813
acc_train:91.10%
acc_test:90.24%
Time:123.3s

(6)显示损失值与准确率变化

fig1 = plt.figure(1, figsize=(12, 6))
ax1 = fig1.add_subplot(2, 2, 1)
ax1.plot(loss_train_list, 'r-')
ax1.set_title('loss_train')
ax2 = fig1.add_subplot(2, 2, 2)
ax2.plot(acc_train_list, 'b-')
ax2.set_title('acc_train')
ax3 = fig1.add_subplot(2, 1, 2)
ax3.plot(acc_test_list, 'b-')
ax3.set_title('acc_test')
plt.show()

训练结果:

(7)结果预测

选取一个数据,观察各神经元的膜电位变化与输出情况

# 设置监视器
for m in model.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True
monitor_o = monitor.OutputMonitor(model, neuron.LIFNode)
monitor_v = monitor.AttributeMonitor('v',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)
print('model:', model)
print('monitor_v:', monitor_v.monitored_layers)
print('monitor_o:', monitor_o.monitored_layers)
# 选择一组输入
x, y = test_dataset[0]
f_out = torch.zeros((y.shape[0], 10))
with torch.no_grad():
    # 逐步传播
    for t in range(T):
        encoded_x = encoder(x.reshape(-1,784))
        f_out += model(encoded_x)
    functional.reset_net(model)
    label = y.argmax().item()
    pred = f_out.argmax().item()
print('label:{},predict:{}'.format(label, pred))
# 膜电位与输出可视化
# 膜电位变化
dpi = 100
figsize = (6, 4)
# 合并列表中的张量,删除多余维度,删除梯度信息
v_list = torch.stack(monitor_v['layer.1']).squeeze().detach()
visualizing.plot_2d_heatmap(array=v_list.numpy(),
                            title='Membrane Potentials',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            int_x_ticks=True,
                            x_max=T,
                            figsize=figsize,
                            dpi=dpi)
# 神经元输出
s_list = torch.stack(monitor_o['layer.1']).squeeze().detach()
visualizing.plot_1d_spikes(spikes=s_list.numpy(),
                            title='Out Spikes',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            figsize=figsize,
                            dpi=dpi)

预测结果:

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

膜电位变化:

神经元输出:

2、多步模式

将单步模式改为多步模式,需要修改以下部分:

(1)将神经元层的步进模式由’s’改为’m’

neuron.LIFNode(tau=2.0,
               decay_input=True,
               v_threshold=1.0,
               v_reset=0.0,
               surrogate_function=surrogate.ATan(),
               step_mode='m',
               store_v_seq=False)

(2)一次将所有时间步的数据全部输入

            encoded_x = encoder(x).repeat(T,1,1))
            f_out += model(encoded_x).sum(axis=0)
            f_out /= T

(3)修改监视器监视的变量

monitor_v = monitor.AttributeMonitor('v_seq',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)

输出情况:

①训练结果

Epoch:10
loss_train: 0.8167978068813682
acc_train:91.06%:
acc_test:89.78%:
Time:145.1s

②网络结构

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

③膜电位变化

④神经元输出:


总结

使用梯度替代法进行反向传播时,使用可微的激活函数替代,避免脉冲的不可微;

使用编码器将输入编码为1/0脉冲序列;

将神经元层代替激活函数;

“在正确构建网络的情况下,逐层传播的并行度更大,速度更快”。但在此逐步传播比逐层传播略快一些。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1629843.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动驾驶新书“五一”节马上上市了

我和杨子江教授合写的《自动驾驶系统开发》终于在清华大学出版社三校稿之后即将在五一节后出版。 清华大学汽车学院的李克强教授和工程院院士撰写了序言。 该书得到了唯一华人图灵奖获得者姚期智院士、西安交大管晓宏教授和科学院院士以及杨强教授和院士等的推荐&#xff0c;…

java:SpringBootWeb请求响应

Servlet 用java编写的服务器端程序 客户端发送请求至服务器 服务器启动并调用Servlet,Servlet根据客户端请求生成响应内容并将其传给服务器 服务器将响应返回给客户端 javaweb的工作原理 在SpringBoot进行web程序开发时,内置了一个核心的Servlet程序DispatcherServlet,称之…

前端用a标签实现静态资源文件(excel/word/pdf)下载

接上文实现的 前端实现将二进制文件流&#xff0c;并下载为excel文件后&#xff0c; 实际项目中一般都会有一个模版下载的功能&#xff0c;一般都由服务端提供一个下载接口&#xff0c;返回文件流或url地址&#xff0c;然后前端再处理成对应需要的类型的文件。 但是&#xff…

从3秒飞降至25毫秒:揭秘惊艳的接口优化策略!

大家好&#xff0c;最近看到京东云的一位大佬分享的接口优化方案&#xff0c;感觉挺不错的&#xff0c;拿来即用。建议收藏一波或者整理到自己的笔记本中&#xff0c;随时查阅&#xff01; 下面是正文。 一、背景 针对老项目&#xff0c;去年做了许多降本增效的事情&#xf…

如何学习思考能力?如何训练思考能力?思考不一样?问到底 对新敏感 主动 不怕试错 预测 独特一套

简单易行的方法&#xff1a;问到底 一个简单而有效的方法是使用"五个为什么"技术。当面临问题时&#xff0c;反复问自己为什么&#xff0c;至少问五次&#xff0c;以深入了解问题的根本原因。这有助于培养深入思考和分析问题的能力。 对新敏感 学习思考能力的关键…

PotatoPie 4.0 实验教程(23) —— FPGA实现摄像头图像伽马(Gamma)变换

为什么要进行Gamma校正 图像的 gamma 校正是一种图像处理技术&#xff0c;用于调整图像的亮度和对比度&#xff0c;让显示设备显示的亮度和对比度更符合人眼的感知。Gamma 校正主要用于修正显示设备的非线性响应&#xff0c;以及在图像处理中进行色彩校正和图像增强。 以前&am…

JAVA 中间件之 Mycat2

Mycat2应用与实战教程 1.Mycat2概述 1.1 什么是MyCat 官网&#xff1a; http://mycatone.top/ Mycat 是基于 java 语言编写的数据库中间件&#xff0c;是一个实现了 MySQL 协议的服务器&#xff0c;前端用户可以把它看作是一个数据库代理&#xff0c;用 MySQL 客户端工具和…

1. 房屋租赁管理系统(基于springboot/vue的Java项目)

1.此系统的受众 1.1 在校学习的学生&#xff0c;可用于日常学习使用或是毕业设计使用 1.2 毕业一到两年的开发人员&#xff0c;用于锻炼自己的独立功能模块设计能力&#xff0c;增强代码编写能力。 1.3 亦可以部署为商化项目使用。 2. 技术栈 jdk8springbootvue2mysq5.7&8…

C++ 动态链接库DLL创建及使用

一、动态链接库DLL创建 使用VS2022 创建 1、创建新解决方案 创建即可 2、创建动态链接库新项目 右键解决方案 语言选择C&#xff0c;选择动态链接库 填入项目名称&#xff0c;勾选&#xff1a;将解决方案和项目放在同一目录中 点击创建 3、创建后&#xff0c;显示dllmai…

西湖大学赵世钰老师【强化学习的数学原理】学习笔记2节

强化学习的数学原理是由西湖大学赵世钰老师带来的关于RL理论方面的详细课程&#xff0c;本课程深入浅出地介绍了RL的基础原理&#xff0c;前置技能只需要基础的编程能力、概率论以及一部分的高等数学&#xff0c;你听完之后会在大脑里面清晰的勾勒出RL公式推导链条中的每一个部…

使用frp实现内网穿透教程

文章目录 简介frp 是什么&#xff1f;为什么选择 frp&#xff1f; 概念工作原理代理类型 内网穿透教程服务端安装和配置本地Windows&#xff08;客户端&#xff09;安装和配置本地Linux虚拟机&#xff08;客户端&#xff09;安装和配置使用 systemd 管理服务端注意事项 简介 f…

Odoo:全球排名第一的免费开源PLM管理系统介绍

概述 利用开源智造OdooPLM产品生命周期管理应用&#xff0c;重塑创新 实现产品生命周期管理数字化&#xff0c;高效定义、开发、交付和管理创新的可持续产品&#xff0c;拥抱数字化供应链。 通过开源智造基于Odoo开源技术平台打造数字化的产品生命周期管理&#xff08;PLM&am…

MongoDB的安装(Linux环境)

登录到Linux服务器执行 lsb_release -a &#xff0c;即可得知服务器的版本信息为&#xff1a;CentOS 7。 # CentOS安装lsb_release包 [rootlinux100 ~]# sudo yum install redhat-lsb# 查看Linux版本 [rootlinux100 ~]# lsb_release -a LSB Version: :core-4.1-amd64:core-…

redis7 for windows的安装教程

本篇博客主要介绍redis7的windows版本下的安装教程 1.redis介绍 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的&#xff0c;基于内存的数据结构存储系统&#xff0c;可用作数据库、缓存和消息代理。它支持多种数据结构&#xff0c;如字符串、哈希表、列…

Django中的事务

1 开启全局的事务 DATABASES {default: {ENGINE: django.db.backends.mysql, # 使用mysql数据库NAME: tracerbackend, # 要连接的数据库USER: root, # 链接数据库的用于名PASSWORD: 123456, # 链接数据库的用于名HOST: 192.168.1.200, # mysql服务监听的ipPORT: 3306, …

四、OSPF域间路由

注&#xff1a;区域&#xff08;area&#xff09;是以接口进行划分的 描述&#xff1a; R1的g0/0/1接口属于area 0 √ R1属于区域0和区域1 1.设计原则 1、OSPF区域的设计原则&#xff1a; 骨干区域有且只能存在一个 非骨干区域必须和骨干区域相连 多区域时&#…

【Linux实践室】Linux文件打包和解压缩实战指南:tar打包命令操作详解(文末送书)

&#x1f308;个人主页&#xff1a;聆风吟_ &#x1f525;系列专栏&#xff1a;Linux实践室、网络奇遇记 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 一. ⛳️任务描述二. ⛳️相关知识2.1 &#x1f514;打包2.1.1 &#x1f47b;知识点讲解2.1.2 &…

CTFHub Web 信息泄漏(一)

目录遍历 打开题目 点击开始寻找flag 发现在flag_in_here页面中有四个文件夹 点击打开第一个文件夹 发现里面还有四个文件夹 再次点击打开第一个文件夹 里面什么都没有 尝试对所有文件夹依次都点击打开 在2/4中发现flag.txt 点击打开即可得到flag 不太懂这题的难点&#…

碎碎念,最近做了几款小产品...

极简番茄时钟 一款 Mac 版「极简番茄时钟」软件。 知识卡片制作工具 主打简单&#xff0c;同时支持 Markdown 语法。 智能微信助手 让管理变得轻松&#xff0c;沟通更加高效。 感兴趣&#xff0c;欢迎来这里一起交流&#xff0c;限时免费 ~

揭示C++设计模式中的实现结构及应用——行为型设计模式

简介 行为型模式&#xff08;Behavioral Pattern&#xff09;是对在不同的对象之间划分责任和算法的抽象化。 行为型模式不仅仅关注类和对象的结构&#xff0c;而且重点关注它们之间的相互作用。 通过行为型模式&#xff0c;可以更加清晰地划分类与对象的职责&#xff0c;并…