去噪扩散隐式模型

news2024/11/13 8:57:38

dataset_name = "datasets/oxford-102-flowers/"

dataset_repetitions = 2 # 数据集重复
num_epochs = 25
image_size = 64 # 模型训练和生成图像的大小
# KID = 内核初始距离
kid_image_size = 75
# 从噪声中逐步“去噪”或“扩散”到最终图像所需的步骤数。
kid_diffusion_steps = 5
# 这个参数可能用于在训练或可视化过程中,展示从噪声到最终图像的扩散过程中的中间步骤数。
# 它帮助理解模型是如何逐步生成图像的。
plot_diffusion_steps = 20
# 这两个参数可能用于控制训练过程中信号(即实际图像内容)与噪声的比例。在扩散模型的训练初期,
# 噪声比例可能很高,随着训练的进行,信号比例逐渐增加。这些参数有助于精细控制这一过程。
min_signal_rate = 0.02
max_signal_rate = 0.95

# 体系结构
embedding_dims = 32
# 用于将位置信息或频率信息嵌入到模型中。它定义了嵌入的最大频率,这可能影响模型捕捉细节的能力。
embedding_max_frequency = 1000.0 # 频率
widths = [32, 64, 96, 128]
block_depth = 2
# 优化
batch_size = 64
# 指数移动平均(EMA)的衰减率。EMA是一种平滑技术,用于在训练过程中跟踪模型参数的平均值
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4

因为我想弄懂stable Diffusion,但是下载不了权重,把整个源码扒出来后,transformer那块很容易理解,百度人工智能后,知道这个模型需要三个东西,第一个就是从噪音中生成图片的模型,就是这个隐式模型,另一个模型是vae,这个一会再说,大致看了下文本提示那个,首先得分词文本数据,之后它是交由transformer编码为语义向量,之后要把图片数据和文本数据传人模型训练吧,我想先把从噪音中生成图片的模型原理搞懂,之后去粗取精,更改stable diffusion,它太大了

diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        # 计算噪声率和信号率:根据扩散时间,通过self.diffusion_schedule计算噪声率和信号率。
        # 噪声率决定了噪声在混合图像中的比例,而信号率则是真实图像内容的比例。在训练开始时,噪声率
        # 接近1,表示图像主要由噪声组成。
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # 混合图像:将归一化的图像与噪声按计算出的噪声率和信号率进行混合,生成带噪声的图像。
        noisy_images = signal_rates * images + noise_rates * noises 根据随机生成的diffusion_times得到信号率和噪音率,这两个值的平方和是1,之后得到噪音图片

with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            # 训练网络分离噪声数据和真实图片数据
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )
            # 计算梯度并更新参数:根据噪声损失(noise_loss)计算网络参数的梯度,并使用优化器(
            # self.optimizer)更新这些参数。
            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric
        # 计算噪声损失对网络可训练参数的梯度,并更新参数
        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

 pred_noises = network([noisy_images, noise_rates**2], training=training)
        # 计算去噪后的图像:利用预测到的噪声成分(pred_noises)、噪声率(noise_rates)和信号率(signal_rates),
        # 通过公式(noisy_images - noise_rates * pred_noises) / signal_rates来计算去噪后的图像(pred_images)
        # 。这个公式是扩散模型逆扩散步骤的核心,它试图从带有噪声的图像中恢复出原始图像。
        # 噪声的预测:pred_noises 是模型根据含噪图像 noisy_images 和当前的噪声率 noise_rates(或其某种变换,如平方
        # )预测出的噪声成分。这个预测是模型在训练过程中学习到的,旨在捕捉含噪图像中的噪声部分
        # 噪声的去除:通过将预测的噪声 pred_noises 乘以相应的噪声率 noise_rates 并从含噪图像 noisy_images 中减去,
        # 我们试图去除图像中的噪声部分。这里的 noise_rates 可能用于调整噪声的强度和尺度,
        # 信号的恢复:然而,仅仅去除噪声并不足以得到完全干净的图像,因为去除噪声后可能会留下信号的衰减部分。因此,我们需要通过除以
        # signal_rates 来恢复信号的原始强度。这里的 signal_rates 代表了在不同噪声水平下信号的保留程度,其值通常与噪声率成反比,
        # 因为随着噪声的增加,信号的可见性会降低
        # signal_rates 这个术语在标准的扩散模型文献中可能并不常见。在扩散模型中,更常见的是使用与噪声方差(或噪声率的平方)相对应的
        # “信号方差”或“剩余方差”,它表示在给定噪声水平下信号(即原始图像内容)的剩余部分。在实际应用中,这个“剩余方差”可能直接用于调整
        # 信号的强度,而不是通过一个显式的 signal_rates 参数。
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates 网络会先下采样,之后上采样,提取特征,得到噪音数据的预测,之后得到预测的图片,之后计算预测和真实的损失,用噪音损失更新模型参数,之所以只用噪音损失计算梯度,很容易理解,因为这个网络就是根据 噪音图片预测噪音数据的,肯定要根据噪音损失计算梯度,根据我一贯的观点,损失在指引模型训练,你仔细琢磨,就发现训练模型就是科学精神,用真实观察和预测间的损失来比较这个理论是否对,所以随着损失下降,这个模型最终应该能成功预测噪音图片中那些数据是噪音数据,自然其他的数据就是真实图片数据,loss=keras.losses.mean_absolute_error这里 损失是mae,评估指标有self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        self.kid = KID(name="kid") 噪音损失,图片损失,kid,kid是真实图片和生成图片间的核距离吧

经过多个轮次训练后,模型生成图片越来越清晰,说明模型预测越来越准确,它学会了识别噪音数据和真实图片数据

 

 

 

 

 

经过30个轮次后的训练效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2148153.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计Python+Flask微博情感分析 微博舆情预测 微博爬虫 微博大数据 舆情分析系统 大数据毕业设计 NLP文本分类 机器学习 深度学习 AI

首先安装需要的python库, 安装完之后利用navicat导入数据库文件bili100.sql到mysql中, 再在pycharm编译器中连接mysql数据库,并在设置文件中将密码修改成你的数据库密码。最后运行app.py,打开链接,即可运行。 B站爬虫数…

VS code 创建与运行 task.json 文件

VS code 创建与运行 task.json 文件 引言正文创建 .json 文件第一步第二步第三步 运行 .json 文件 引言 之前在 VS code EXPLORER 中不显示指定文件及文件夹设置(如.pyc, pycache, .vscode 文件) 一文中我们介绍了 settings.json 文件,这里我…

商业终端架构技术-未来之窗行业应用跨平台架构

未来之窗行业应用跨平台架构 以下是对未来之窗行业应用跨平台架构中客户端的稳定优势和网页跨平台性质的扩展列举: 一、客户端的稳定优势: 1. 离线可用性 - 即使在没有网络连接的…

Redis的Key的过期策略是怎样实现的?

在学习Redis时,我们知道可以设置Key的过期时间,我们还知道,Redis一大特点–速度快。 那么当Redis中的数据量起来时,如果直接遍历所有的Key,那么对于Key过期时间的校验应该很费时间,那么Redis究竟是怎样做的…

前端vue-插值表达式和v-html的区别

创建vue实例的时候,可以有两种形式。 1.let appnew Vue({}) 2 const appnew Vue({}) 3 el是挂载点,是上面div的id值 4 data中的值可以展示在上面div中 5 v-html标签里面如果有内容,则我们的新内容会把标签里面的内容覆盖掉

2024 vue3入门教程:02 我的第一个vue页面

1.打开src下的App.vue,删除所有的默认代码 2.更换为自己写的代码, 变量msg:可以自定义为其他(建议不要使用vue的关键字) 我的的第一个vue:可以更换为其他自定义文字 3.运行命令两步走 下载依赖 cnpm i…

Java项目实战II基于Java+Spring Boot+MySQL的酒店客房管理系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在旅游与酒…

AI助力智慧农田作物病虫害监测,基于YOLOv8全系列【n/s/m/l/x】参数模型开发构建花田作物种植场景下棉花作物常见病虫害检测识别系统

智慧农业是一个很大的应用市场,将当下如火如荼的AI模型技术与现实的农业生产场景相结合能够有效提升生产效率,农作物在整个种植周期中有很多工作需要进行,如:浇水、施肥、除草除虫等等,传统的农业作物种植生产管理周期…

2024java高频面试-数据库相关

前言:趁着年轻,博主准备在拼一把,看能不能挑个可以干到退休的牛马工位!!! 废话不多说,面试真题赶紧一股脑倒进我的脑袋瓜子里吧!!! 事务四大特性&#xff1f…

268页PPT大型集团智慧工厂信息化顶层架构设计(2024版)

智能制造装备是高端制造业的关键,通过整合智能传感、控制、AI等技术,具备了信息感知、分析规划等智能化功能,能显著提升加工质量、效率和降低成本。该装备是先进制造、信息、智能技术的深度融合。其原理主要包括物联网集成、大数据分析与人工…

计算机毕业设计hadoop+spark知网文献论文推荐系统 知识图谱 知网爬虫 知网数据分析 知网大数据 知网可视化 预测系统 大数据毕业设计 机器学习

《HadoopSpark知网文献论文推荐系统》开题报告 一、研究背景及意义 随着互联网技术的迅猛发展和大数据时代的到来,学术文献的数量呈爆炸式增长,用户面临着严重的信息过载问题。如何高效地从海量文献中筛选出用户感兴趣的论文,成为当前学术界…

黑鲨机型“工程固件” 清除nv资源预览 写入以及修复基带解析

黑鲨手机是专门为中国玩家制作是游戏科技手机。液冷散热技术被第一次运用在手机上,为手机散热领域竖立了新的标杆,同时通过“X元素”,运用跑车流线型设计,打造属于黑鲨的设计语言。超旗舰的硬件配置,辨识度极高的外观设计,让黑鲨手机成为了硬核玩家的标配。 黑鲨机型从1…

Web开发:Thymeleaf模板引擎

1. Thymeleaf 简介 Thymeleaf 是一个现代的服务器端模板引擎,用于生成 HTML、XML、JavaScript 和 CSS。它的设计理念是使模板能够自然地在 Web 浏览器中呈现,同时允许动态生成内容。 2. 最佳实践总结 2.1 项目结构和模板组织 保持清晰的目录结构&…

钢铁焦化水泥超低排的原因分析有哪些建议

实施超低排放的原因分析及其建议,朗观视觉小编建议,大家可以从以下几个方面进行阐述: 一、原因分析 环境保护需求: 随着环保意识的增强和环保法规的日益严格,减少大气污染物排放已成为行业发展的必然趋势。钢铁、焦化…

MT8370|MTK8370(Genio 510 )安卓核心板参数介绍

MTK Genio 510 (MT8370)安卓核心板是一款极为先进的高性能平台,专为满足边缘处理、先进多媒体功能及全面的连接需求而设计,适用于多种人工智能(AI)和物联网(IoT)应用场景。它具备多个高分辨率摄像头支持和可联网触摸屏显示,适用于使用多任务高…

Swagger 概念和使用以及遇到的问题

前言 接口文档对于前后端开发人员都十分重要。尤其近几年流行前后端分离后接口文档又变 成重中之重。接口文档固然重要,但是由于项目周期等原因后端人员经常出现无法及时更新, 导致前端人员抱怨接口文档和实际情况不一致。 很多人员会抱怨别人写的接口文档不…

一个手机号注册3个抖音号的绿色方法?一个人注册多个抖音号的方法!

下面这是我注册的新账号,显示未实名,在手机号这里显示辅助手机号绑定,手机号绑定这里显示未绑定。如果你需要矩阵,那么,还需要设置好头像,以及介绍,这些都可以正常设置。 再好的方法&#xff0c…

【IPV6从入门到起飞】5-5 IPV6+Home Assistant(HACS商店安装)docker版本安装

IPV6Home Assistant[HACS商店安装]docker版本安装 1 背景2 下载HACS3 安装/启用 HACS4 拓展安装 1 背景 在hass中,是有在线商店供我们下载插件,用于美化hass以及拓展功能,但是在docker版本中,默认是没有的,开启高级模…

【有啥问啥】深入解析:机器学习中的过拟合与欠拟合

深入解析:机器学习中的过拟合与欠拟合 在机器学习中,过拟合(overfitting)和欠拟合(underfitting)是模型性能中常见的两大挑战。它们反映了模型的学习能力与泛化能力的不平衡,直接影响模型在训练…

【machine learning-九-梯度下降】

梯度下降 更加通用的梯度下降算法算法步骤 上一节讲过,随机的寻找w和b使损失最小不是一种合适的方法,梯度下降算法就是解决解决这个问题的,它不仅可以用于线性回归,还可以用于神经网络等深度学习算法,是目前的通用性算…