Denosing score matching (公式+代码)

news2024/11/28 18:53:56

加噪声的分数匹配

在这里插入图片描述

def anneal_dsm_score_estimation(scorenet, samples, labels, sigmas, anneal_power=2.):
    # 取出每个样本对应噪声级别下的噪声分布的标准差,即公式中的sigma_i,
    # 这里的 labels 是用于标识每个样本的噪声级别的,就是 i,实际是一种索引标识
    # (bs,)->(bs,1,1,1) 扩展至与图像一致的维度数
    used_sigmas = sigmas[labels].view(samples.shape[0], *([1] * len(samples.shape[1:])))
    # 加噪:x' = x + sigma * z (z ~ N(0,1))
    perturbed_samples = samples + torch.randn_like(samples) * used_sigmas
    
    # 目标score,本质是对数条件概率密度 log(p(x'|x)) 对噪声数据 x' 的梯度
    # 由于这里建模为高斯分布,因此可计算出结果最终如下,见前文公式(vii)
    target = - 1 / (used_sigmas ** 2) * (perturbed_samples - samples)
    # 模型预测的 score
    scores = scorenet(perturbed_samples, labels)
    target = target.view(target.shape[0], -1)
    scores = scores.view(scores.shape[0], -1)

    # 先计算每个样本在所有维度下分数估计的误差总和,再对所有样本求平均
    # 见前文公式(vii)
    loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1) * used_sigmas.squeeze() ** anneal_power

    return loss.mean(dim=0)

采样生成:

在这里插入图片描述

def anneal_Langevin_dynamics(self, x_mod, scorenet, sigmas, n_steps_each=100, step_lr=0.00002):
    images = []

    with torch.no_grad():
        # 依次在每个噪声级别下进行朗之万动力学采样生成,噪声强度递减
        for c, sigma in tqdm.tqdm(enumerate(sigmas), total=len(sigmas), desc='annealed Langevin dynamics sampling'):
            # 噪声级别
            labels = torch.ones(x_mod.shape[0], device=x_mod.device) * c
            labels = labels.long()

            # 这个步长并非 Algorithm 1 中的 alpha,而是其中第6步的 alpha/2
            # 对应朗之万动力学采样公式(见公式(vi))的 epsilon/2
            step_size = step_lr * (sigma / sigmas[-1]) ** 2
            
            # 每个噪声级别下进行一定步数的朗之万动力学采样生成
            for s in range(n_steps_each):
                images.append(torch.clamp(x_mod, 0.0, 1.0).to('cpu'))
                # 对应公式(vi)最后一项
                noise = torch.randn_like(x_mod) * np.sqrt(step_size * 2)
                # 网络估计的分数
                grad = scorenet(x_mod, labels)
                # 朗之万动力方程
                x_mod = x_mod + step_size * grad + noise

        return images

详细的解释(强推):https://zhuanlan.zhihu.com/p/597490389

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/741800.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaEE——常见的锁策略、CAS、synchronized 原理(八股)

文章目录 一、常见的锁策略1.乐观锁 & 悲观锁2.轻量级锁 & 重量级锁3.自旋锁 & 挂起等待锁4.互斥锁 & 读写锁5. 公平锁 & 非公平锁 二、CAS1、什么是 CAS2. CAS 的应用场景2.实现自旋锁3. CAS 中的 ABA 问题 三、 Synchronized 原理 一、常见的锁策略 当前…

DataFun: ChatGPT背后的模型详解

ChatGPT背后的模型详解 Overview Transofrmer 各个构件都有一定的作用 Multi-head self attention 每个字的重要性不一样,学习QKV三个矩阵(query,key,value) 多组QKV RLHF ChatGPT训练过程 思维链 COT

C++_简单模拟实现string的基本结构

C中,string早于STL问世。使用string中的构造函数可以实现对string类型的字符串的一系列操作。 今天来模拟C中的string的基本结构。注意仅仅是简单模拟,string内部结构其实非常复杂,并且不同版本的IDEstring的内部结构也不尽相同。尽管有所不…

SpringBoot2+Vue2实战(十五)高德地图集成

1.地图官网&#xff1a; 高德开放平台 | 高德地图API 2.开发文档(web js) 正式集成&#xff1a; 1.再index.html中引入script标签 <script type"text/javascript" src"https://webapi.amap.com/maps?v2.0&key您申请的key值"></script>…

第五章 PCIe介绍 5.1-5.7

5.1 从PCIe的速度说起 为什么SSD要用PCIe接口&#xff1f;因为它快&#xff0c;比SATA快。 Lane&#xff1a;通道&#xff0c;PCIe最多可以有32个通道。 1. PCIe的工作模式 两个设备之间的PCIe连接&#xff0c;叫做一个Link。如下图&#xff0c;设备A和设备B是个双向连接&#…

【读书笔记】只管去做

《只管去做》是一本很容易读完的书&#xff0c;这本书是以故事的形式来阐述把愿景落实到每天的行动中的方法&#xff0c;对我们做人生规划很有帮助。

使用leaflet在html中加载天地图且去掉左上角的缩放图标以及右下角的logo

前言 我们这一节使用轻量化的javascript库leaflet来实现在html中加载天地图&#xff0c;实现类似高德地图、百度地图的效果。 效果图如下&#xff1a; 话不多说&#xff0c;进入主题&#xff01;&#xff01; 一、注册开发者权限 我们需要在天地图平台注册一个账号&#xff0…

【Qt】VS2013+QT5.6.3环境搭建

安装VS2013 略 安装Qt 安装文件&#xff1a;qt-opensource-windows-x86-msvc2013-5.6.3.exe&#xff08;官网已经不提供下载了。&#xff09; 安装步骤&#xff1a;安装到C盘根目录&#xff0c;其它略。 安装qt vs插件 1、下载地址&#xff1a; https://download.qt.io/a…

string常见功能模拟

学到string终于就不用像c语言一样造轮子了&#xff0c;接下来我们就模拟一下string方便我们更好理解string&#xff0c;首先我们都知道库里有个string&#xff0c;所以为了避免我们的string和库里的冲突&#xff0c;要用命名空间my_string将我们写的string包含在内。string的成…

精准医学时代:探索人工智能在DCA曲线下的临床医学应用

一、引言 在当今医学领域中&#xff0c;精准医学作为一种以个体差异为基础的医疗模式逐渐受到重视和应用[1]。精准医学基于个体基因组、环境和生活方式因素的综合分析&#xff0c;旨在实现个体化的预防、诊断和治疗方案&#xff0c;从而提供更好的临床结果[2]。与传统医学相比&…

MACD进阶版指标公式,提前一天判断MACD金叉

MACD是一种常用的技术分析指标&#xff0c;用于判断价格的趋势和动能&#xff0c;其原理是基于两条指数移动平均线的比较和对价格的平滑处理&#xff0c;MACD金叉是指MACD指标中的快线DIF从下方向上穿过慢线DEA。快线、慢线都是根据收盘价计算出来的&#xff0c;如果想提前一天…

STM32基础知识点总结

一、基础知识点 1、课程体系介绍 单片机概述arm体系结构STM32开发环境搭建 STM32-GPIO编程-点亮世界的那盏灯 STM32-USART串口应用SPI液晶屏 STM32-中断系统 STM32-时钟系统 STM32-ADC DMA 温湿度传感器-DHT11 2.如何学习单片机课程 多听理论、多理解、有问题及时提问 自己多…

论文阅读:基于深度学习的大尺度遥感图像建筑物分割研究

一、该网络中采用了上下文信息捕获模块。通过扩大感受野&#xff0c;在保留细节信息的同时&#xff0c;在中心部分进行多尺度特征的融合&#xff0c;缓解了传统算法中细节信息丢失的问题&#xff1b;通过自适应地融合局部语义特征&#xff0c;该网络在空间特征和通道特征之间建…

时间序列预测 | Matlab基于粒子群算法(PSO)优化径向基神经网络(PSO-RBF)的时间序列预测

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 时间序列预测| Matlab基于粒子群算法(PSO)优化径向基神经网络(PSO-RBF)的时间序列预测 评价指标包括:MAE、MBE和R2等,代码质量极高,方便学习和替换数据。要求2018版本及以上。 部分源码 %% 清空环境变量 warni…

2023年开放式蓝牙耳机选购指南!多款热门开放式蓝牙耳机品牌盘点

前言 大家好&#xff0c;作为专注耳机研究多年的发烧级爱好者&#xff0c;毫不夸张地说我为耳机花的钱比买衣服还多&#xff0c;很多人都在问我开放式耳机到底有没有必要买&#xff1f;答案毫无疑问是有必要&#xff01;开放式耳机佩戴舒适又安全的特质让它在耳机届风靡&#…

zabbix server is not running错误解决方法

1.错误&#xff1a;zabbix server is not running 打开zabbix server的时候&#xff0c;底部飘着一行黄色的警告字 2.解决方法 (1)关闭selinux (2)查看日志文件 #tail -f /var/log/zabbix/zabbix_server.log 发现内存溢出了 __zbx_mem_realloc(): out of memory 那…

vitepress使用

vitepress使用 初始化项目 pnpm init pnpm add vitepress vue 创建一个docs文件夹 在docs下新建index.js文件 # Hello VitePress在package.json中增加打包以及运行的指令 "scripts": {"docs:dev": "vitepress dev docs", // 本地运行调试&qu…

springboot高校党务系统

开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.3.9

实力认可丨通付盾上榜《嘶吼2023网络安全产业图谱》31项细分领域

7月10日&#xff0c;嘶吼安全产业研究院联合国家网络安全产业园区&#xff08;通州园&#xff09;正式发布《嘶吼2023网络安全产业图谱》。通付盾入围本次图谱的基础技术与通用能力、网络与通信安全、安全服务、应用与产业安全、数据安全、开发与应用安全六大类别&#xff0c;3…