SNN demo

news2025/1/20 18:22:34

记录一个同门给的SNN demo,仅供自己参考

1 SNN和ANN代码的差别

SNNANN的深度学习demo还是差一些的,主要有下面几个:

  • 输入差一个时间维度T,比如:在cv中,ANN的输入是:[B, C, W, H],SNN的输入是:[B, T, C, W, H]
    补充
    为什么snn需要多一个时间维度?
    因为相较于ann在做分类后每个神经元可以输出具体的数字(比如在分类问题中这个数字表示概率),但snn每个神经元的输出都是01。解决方法就是那么可以模拟时间步(time steps),让这个前向传播的过程多来几次,最后看哪个神经元输出的1比较多,就作为最终结果(类似于ann里输出的数字最大的那个),在train中和labelloss,在应用中就作为模型对应输出。

  • ANN求梯度时可以直接用backward()SNN由于不可导,需要手写反向传播

  • SNN中涉及神经元的选择问题(比如LIF, IF, SRM神经元等)

  • ANN的输入输出都是具体数值,而SNN的输入输出都是脉冲

  • SNN的数据流传播过程是:spike -> u -> spike ,u指的是膜电压membrane potential

2 SNN demo讲解

2.1 定义模型

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            Linear(784, 800),	
            IF(),
            Linear(800, 10),
            IF()
        )

    def forward(self, x):
        return self.model(x)

2.2 重新定义Linear

由于nn.Linear()这个函数只能是B * CWH(以cv为例,C, W, H是表示特征的),SNN的数据流需要转化成BT * CWH的形式,经过Linear才有意义,所以重新定义了Linear()

class Linear(Layer):
    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 device=None, dtype=None) -> None:
        super(Linear, self).__init__()
        self.model = nn.Linear(in_features, out_features, bias, device, dtype)


class Layer(nn.Module):
    def __init__(self) -> None:
        super(Layer, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input stimuli from pre-synapses in T time steps, shape=[N, T, D], while N is batch size,
        T is time step, D is feature dimension.

        :return: summation of pre-synapses stimuli to post-synapses neurons through synapse efficiency,
        each time step are integrated independently.
        """
        return forward_with_time(self.model, x)


def forward_with_time(model: nn.Module, x: torch.Tensor) -> torch.Tensor:

    batch_size, steps = x.shape[:2]		# x.shape[0-1]
    out = model(x.flatten(0, 1).contiguous())	# [N, T, D] -> [N * T, D]
    return out.view(batch_size, steps, *out.shape[1:])	# 将经过Linear后的数据再还原成[N, T, D]这样的维度

2.3 神经元定义

spike -> u -> spike这样的数据流是在神经元中实现的,我们以IF神经元为例:

class IF(nn.Module):
    def __init__(self, threshold=1., rest=0., surrogate=sigmoid):
        super(IF, self).__init__()
        self.threshold = threshold
        self.rest = rest
        self.surrogate = surrogate.apply

    def forward(self, inputs):
        return self.integrate_fire(inputs)

    def integrate_fire(self, inputs):
        u = 0
        spikes = torch.zeros_like(inputs)
        for i in range(inputs.shape[1]):	# T
            u += inputs[:, i]
            spikes[:, i] = self.surrogate(u - self.threshold)
            u = u * (1 - spikes[:, i]) + self.rest * spikes[:, i]
        return spikes

integrate_fire函数中,我们不妨举这样一个小例子来模拟一下过程:

t = torch.rand(3, 3)
zero_t = torch.zeros_like(t)
print(t)
print(zero_t)
u = 0
for i in range(t.shape[1]):
    print(t[:, i])
    u += t[:, i]
    print(u)  # 单独一个冒号代表从头取到尾

在这里插入图片描述
假设每一列代表一排神经元,那么每一次循环其实就是对一排神经元做处理的过程,循环次数为共有多少列(也就是第一维度时间步T)。当spike作为input输进来时,先影响膜电压u,然后根据u,决定输出什么spike。由于输出了spike,自身也要做调整。上面的过程就是integrate_fire()函数的过程,不同神经元的差别也就在于此。

2.4 代理梯度

代理梯度这里用的是sigmoid

class sigmoid(basic_surrogate):
    @staticmethod
    def backward(ctx, grad_out):
        sgax = (ctx.saved_tensors[0] * ctx.alpha).sigmoid_()
        return grad_out * (1. - sgax) * sgax * ctx.alpha, None      # sigmoid: σ(x), σ'(x) = σ(x)(1-σ(x))

为了用backward还得把forward补齐,因此完整的反向传播代码如下:

def spike_emiting(potential_cond):
    """
    """
    return potential_cond.ge(0.0).to(potential_cond)	# u - threshold > 0 才会 emit spike


class basic_surrogate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, alpha=4.):		# alpha的作用是改变sigmoid的形状,让它更逼近神经元发放脉冲时的图像
        if inputs.requires_grad:
            ctx.save_for_backward(inputs)
            ctx.alpha = alpha
        return spike_emiting(inputs)


class sigmoid(basic_surrogate):
    @staticmethod
    def backward(ctx, grad_out):
        sgax = (ctx.saved_tensors[0] * ctx.alpha).sigmoid_()
        return grad_out * (1. - sgax) * sgax * ctx.alpha, None      # sigmoid: σ(x), σ'(x) = σ(x)(1-σ(x))

3 SNN demo 完整版

解析看不懂没关系,如果要用的话只需要修改下面几个地方:

  • 输入输出都是spike形式,所以要保证自己的输入是[B, T, D]的形式,D可以是[C, H, W]cv),也可以是其他
  • 神经元选用的是IF神经元,如果要用别的就修改一下2.3integrate_fire()函数
  • 网络结构是两层全连接,修改网络结构的话在2.1下面的代码部分修改
  • 要修改代理梯度的函数,去2.4
  • 要修改其他ANNmodel,去2.2

要我的话可能就改前两个…()
最后奉上完整demo(还没测试过等测试完就把括号里这个划掉)

import torch
import torch.nn as nn


@torch.jit.script
def spike_emiting(potential_cond):
    """
    """
    return potential_cond.ge(0.0).to(potential_cond)


class basic_surrogate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, alpha=4.):
        if inputs.requires_grad:
            ctx.save_for_backward(inputs)
            ctx.alpha = alpha
        return spike_emiting(inputs)


class sigmoid(basic_surrogate):
    @staticmethod
    def backward(ctx, grad_out):
        sgax = (ctx.saved_tensors[0] * ctx.alpha).sigmoid_()
        return grad_out * (1. - sgax) * sgax * ctx.alpha, None      # sigmoid: σ(x), σ'(x) = σ(x)(1-σ(x))


class IF(nn.Module):
    def __init__(self, threshold=1., rest=0., surrogate=sigmoid):
        super(IF, self).__init__()
        self.threshold = threshold
        self.rest = rest
        self.surrogate = surrogate.apply

    def forward(self, inputs):
        return self.integrate_fire(inputs)

    def integrate_fire(self, inputs):
        u = 0
        spikes = torch.zeros_like(inputs)
        for i in range(inputs.shape[1]):
            u += inputs[:, i]
            spikes[:, i] = self.surrogate(u - self.threshold)
            u = u * (1 - spikes[:, i]) + self.rest * spikes[:, i]
        return spikes


# 由于多一个维度T,在使用torch.nn的层时需要多一步处理,每个t的脉冲要独立加权
def forward_with_time(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """
    ..code-block:: python
        B, T = 256, 100
        l1 = nn.Conv2d(1, 16, 3)
        l2 = nn.AvgPool2d(2, 2)
        out1 = forward_with_time(l1, torch.randn(B, T, 1, 28, 28))
        out2 = forward_with_time(l2, out1)

    """
    batch_size, steps = x.shape[:2]
    out = model(x.flatten(0, 1).contiguous())
    return out.view(batch_size, steps, *out.shape[1:])


class Layer(nn.Module):
    def __init__(self) -> None:
        super(Layer, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input stimuli from pre-synapses in T time steps, shape=[N, T, D], while N is batch size,
        T is time step, D is feature dimension.

        :return: summation of pre-synapses stimuli to post-synapses neurons through synapse efficiency,
        each time step are integrated independently.
        """
        return forward_with_time(self.model, x)


class Linear(Layer):
    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 device=None, dtype=None) -> None:
        super(Linear, self).__init__()
        self.model = nn.Linear(in_features, out_features, bias, device, dtype)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            Linear(784, 800),
            IF(),
            Linear(800, 10),
            IF()
        )

    def forward(self, x):
        return self.model(x)



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/427653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring WebFlow-远程代码执行漏洞(CVE-2017-4971)

Spring WebFlow-远程代码执行漏洞(CVE-2017-4971) 0x00 前言 Spring WebFlow 是一个适用于开发基于流程的应用程序的框架(如购物逻辑),可以将流程的定义和实现流程行为的类和视图分离开来。在其 2.4.x 版本中&#x…

浅说情绪控制被杏仁体劫持

2023年4月16号,没想到被杏仁体劫持那么严重,触发手抖和口干的症状,这个还真是自己万万没有想到的。 人生要修炼两条线:一条明线是做的事情,那是自己要做的具体事情。 一条暗线是修炼的自己,这次也做了测试…

云安全——Docker Daemon

0x00 前言 其他云安全相关内容,请参考:云安全知识整理 0x01 Docker Daemon Daemon是Docker的守护进程,Docker Client通过命令行与Docker Damon通信,完成Docker相关操作,2375端口是Daemon的未授权端口。 0x01 2375 …

设计模式之监听模式

本文将会介绍设计模式中的监听模式。   监听模式是一种一对多的关系,可以有任意个(一个或多个)观察者对象同时监听某一个对象。监听的对象叫观察者(Observer),被监听的对象叫作被观察者(Obser…

QObject对象生命周期管理

QObject对象生命周期管理 1.C中对象的生命周期管理是一个非常重要的话题,因为C需要程序员自己手动管理内存,而这也是C程序经常容易出现内存问题的重要原因。 1.1 特别是多线程环境下如何正确管理好对象的生命周期,更是C程序开发中的一个难点…

对抗样本-(CVPR 2022)-通过基于对象多样化输入来提高有针对性对抗样本的可迁移性

论文地址:https://arxiv.org/abs/2203.09123 代码地址:https://github.com/dreamflake/ODI 摘要:本文提出了一种新的方法来生成有针对性的对抗样本,该方法通过使用多种不同的输入图像来生成更加丰富和多样化的图像。具体而言&…

hashlib模块

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起探讨和分享Linux C/C/Python/Shell编程、机器人技术、机器学习、机器视觉、嵌入式AI相关领域的知识和技术。 hashlib模块专栏:《python从入门到实战》 哈希算法,也叫摘要算法。 加密&…

Postcat 插件上线,支持 ApiPost 格式导入

作为开源的 API 管理工具,Postcat 已经支持 Postman、swagger、Eolink 等平台的数据导入导出。 前不久有用户跟我们提需求,想要 Postcat 支持国内的一些主流的 API 管理工具,好消息是现在不就支持了么! 最近我们的插件广场上线了一…

etcd概述

本文主要介绍了 etcd 相关概念,以及 etcd 的主要使用场景 1. 介绍 etcd 是云原生架构中的基础组件,由 CNCF 孵化托管。etcd 在微服务和 kubernetes 集群中不仅可以作为服务注册中心用于服务发现,还可以作为 key-value 存储中间件etcd 是 Co…

Spring的Bean初始化过程和生命周期

Spring的Bean初始化过程和生命周期一、Spring创建bean的流程图二、Spring创建bean的详细流程1.加载bean信息2.实例化bean3.bean属性填充4.初始化bean5.后置操作三、bean的生命周期四、总结Spring的核心功能有三点IOC、DI、AOP,IOC则是基础,也是Spring功能…

Python+Requests模拟发送post请求

模拟发送post请求 发送post请求的基础知识dumps和loads 代码示例: # 发送post请求 import requests,json # 发送post请求的基础知识dumps和loads str_dict {name:xiaoming,age:20,sex:男} print(type(str_dict)) str1 json.dumps(str_dict) # 1,json.dumps 是把…

git 本地新建并提交上传仓库

初始化步骤基本解释 新建readme touch README.md 初始化仓库 git init 添加仓库下所有文件 git add . 提交 备注到本地 git commit -m "备注" 链接远程git库 git remote add origin 新建库ssh链接 上传代码 git push -u origin master 初始化操作步骤 touch README.…

【Ubuntu】Ubuntu20基础配置+go开发配置

这里写自定义目录标题1 基础配置1.1 安装ifconfig网络管理工具1.2 初始化root密码1.3 换镜像源1.4 关闭息屏休眠1.5 关闭自动更新2 开发环境2.1 go2.1.1 建立软件目录并安装软件2.1.2 建立go工作目录2.1.3 配置环境变量2.2 mysql2.2.1 安装2.2.2 建立对外用户并更改密码2.2.3 修…

江苏三年制专转本法学类考纲配套课程网课题库

江苏三年制专转本法学类考纲配套课程网课题库1、江苏专转本的考试科目都有哪些? 2022年开始江苏专转本成绩主要由语文/数学英语/日语专业课三科的成绩构成,满分500分。分别给大家解释一下 语文/数学:满分150分(文科考语文&#xf…

[源码解析]socket系统调用上

文章目录socket函数API内核源码sock_createinet_createsock_allocsock_map_fd相关数据结构本文将以socket函数为例,分析它在Linux5.12.10内核中的实现,先观此图,宏观上把握它在内核中的函数调用关系:socket函数API socket 函数原…

王小川,才是深「爱」李彦宏的那个人?

在推出中国首个类ChatGPT产品「文心一言」后,李彦宏在接受专访时断言,中国基本不会再出一个OpenAI了,「创业公司重新做一个ChatGPT其实没有多大意义,基于大语言模型开发应用机会很大,没有必要再重新发明一遍轮子。」 听…

【AI理论学习】深入理解扩散模型:Diffusion Models(DDPM)(理论篇)

深入理解扩散模型:Diffusion Models引言扩散模型的原理扩散过程反向过程优化目标模型设计代码实现Stable Diffusion、DALL-E、Imagen背后共同的套路Stable DiffusionDALL-E seriesImagenText encoderDecoder什么是FID(Frechet Inception Distance&#x…

uni-app--》如何实现网上购物小程序(上)?

🏍️作者简介:大家好,我是亦世凡华、渴望知识储备自己的一名在校大学生 🛵个人主页:亦世凡华、 🛺系列专栏:uni-app 🚲座右铭:人生亦可燃烧,亦可腐败&#xf…

YOLOv8源码逐行解读(yolov8.yaml)(更新中)

本人也是刚接触YOLO不久的菜鸟一个,写博客主要是记录自己的学习过程,如果有写的不对的地方欢迎大家批评指正! yolov8.yaml 官方下载地址:https://github.com/ultralytics/ultralytics/tree/main/ultralytics/models/v8 # Ultral…

工业机器人三大主流行业浮动去毛刺应用深度详解

工业机器人是一种能够自动执行各种工业任务的机器人,它的使用不仅能将工人从繁重或有害的体力劳动中解放出来,解决当前劳动力短缺问题,而且能够提高生产效率和产品质量,增强企业整体竞争力,被广泛地应用于工业各个生产…