把组合损失中的权重设置为可学习参数

news2024/10/10 6:17:34

目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。

目录

1、构建组合损失类

2、调用组合损失类

3、为其构建优化器

4、梯度归零

5、跟新优化器参数

6、结果展示


1、构建组合损失类

每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。

# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        # 定义损失函数权重作为可训练参数
        self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True))  # 对抗损失的权重,初始值为0.2 
        self.w_con = nn.Parameter(torch.ones(1, requires_grad=True))  # 内容感知损失的权重,初始值为0.2
        self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True))  # 均方误差损失的权重,初始值为0.2
        self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True))  # 随机结构相似性损失的权重,初始值为0.2
        self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True))  # 边缘引导损失的权重,初始值为0.2


    def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):
        return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui

2、调用组合损失类

在计算组合损失之前,需要初始化类对象。

combinedloss = Loss.CombinedLoss()

unet_loss = self.combinedloss(
                            loss_adv = unet_gan_loss, 
                            loss_con = gen_content_loss, 
                            loss_mse = unet_criterion, 
                            loss_s3im = s3im_loss, 
                            loss_gui = guid_loss)

3、为其构建优化器

最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。

self.lr_weight_optimizer = optim.Adam(
            self.combinedloss.parameters(),
            lr = 1e-4,
            betas=(0.9, 0.999)
        )

4、梯度归零

在每次计算总损失之前,需要把每个优化器的梯度归零

self.lr_weight_optimizer.zero_grad()

5、跟新优化器参数

在总损失反向传播之后,需要对优化器的参数进行更新

self.lr_weight_optimizer.step()

6、结果展示

每个权重都会自动更新。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1554364.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HiRoPE、MoDiTalker、RecDiffusion、DreamSalon、InterDreamer、BAMM

本文首发于公众号:机器感知 HiRoPE、MoDiTalker、RecDiffusion、DreamSalon、InterDreamer、BAMM Lift3D: Zero-Shot Lifting of Any 2D Vision Model to 3D In recent years, there has been an explosion of 2D vision models for numerous tasks such as seman…

利用lidar生成深度图

前言 目前,深度图像的获取方法有:激光雷达深度成像法、计算机立体视觉成像、坐标测量机法、莫尔条纹法、结构光法等。针对深度图像的研究重点主要集中在以下几个方面:深度图像的分割技术,深度图像的边缘检测技术,基于…

HarmonyOS实战开发-实现自定义弹窗

介绍 本篇Codelab基于ArkTS的声明式开发范式实现了三种不同的弹窗,第一种直接使用公共组件,后两种使用CustomDialogController实现自定义弹窗,效果如图所示 相关概念 AlertDialog:警告弹窗,可设置文本内容和响应回调…

网络七层模型:理解网络通信的架构(〇)

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

【QT入门】 QListWidget各种常见用法详解之列表模式

往期回顾 【QT入门】 Qt代码创建布局之setLayout使用-CSDN博客 【QT入门】 Qt代码创建布局之多重布局变换与布局删除技巧-CSDN博客 【QT入门】 QTabWidget各种常见用法详解-CSDN博客 【QT入门】 QListWidget各种常见用法详解之列表模式 QListWidget有列表和图标两种显示模式&a…

数据结构刷题篇 之 【力扣二叉树基础OJ】详细讲解(含每道题链接及递归图解)

有没有一起拼用银行卡的,取钱的时候我用,存钱的时候你用 1、相同的树 难度等级:⭐ 直达链接:相同的树 2、单值二叉树 难度等级:⭐ 直达链接:单值二叉树 3、对称二叉树 难度等级:⭐⭐ 直达…

Delphi模式编程

文章目录 Delphi模式编程涉及以下几个关键方面:**设计模式的应用****Delphi特性的利用****实际开发中的实践** Delphi模式编程的实例 Delphi模式编程是指在使用Delphi这一集成开发环境(IDE)和Object Pascal语言进行软件开发时,采用…

vivado 器件编程

生成器件镜像后 , 下一步是将其下载到目标器件。 Vivado IDE 具有内置原生的系统内器件编程功能用于执行此操作。 Vivado Design Suite 和 Vivado Lab Edition 都包含相应的功能 , 支持您连接到包含 1 个或多个 FPGA 或 ACAP 的硬 件, 以…

9、鸿蒙学习-开发及引用静态共享包(API 9)

HAR(Harmony Archive)是静态共享包,可以包含代码、C库、资源和配置文件。通过HAR可以实现多个模块或多个工程共享ArkUI组件、资源等相关代码。HAR不同于HAP,不能独立安装运行在设备上,只能作为应用模块的依赖项被引用。…

使用Docker Compose一键部署前后端分离项目(图文保姆级教程)

一、安装Docker和docker Compose 1.Docker安装 //下载containerd.io包 yum install https://download.docker.com/linux/fedora/30/x86_64/stable/Packages/containerd.io-1.2.6-3.3.fc30.x86_64.rpm //安装依赖项 yum install -y yum-utils device-mapper-persistent-data l…

VTK 9.2.6 源码和VTK Examples 编译 Visual Studio 2022

对于编译 VTK 源码和编译详细的说明: VTK 源码编译: 下载源码: 从 VTK 官方网站或者 GitHub 获取源代码。官网目前最近的9.3.0有问题,见VTK 9.3.0 编译问题 Visual Studio 2022去gitlab上选择9.2.6分支进行clone CMake 配置&…

探索数据结构:链式队与循环队列的模拟、实现与应用

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:数据结构与算法 贝蒂的主页:Betty’s blog 1. 队列的定义 队列(queue)是一种只允许在一端进…

系统分析师-参考模型

前言 网络术语中的参考模型指的是OSI参考模型,由ISO(国际标准化组织)制定的一套普遍适用的规范集合,以使得全球范围的计算机平台可进行开放式通信。 ISO创建了一个有助于开发和理解计算机的通信模型,即开放系统互联OS…

openwrt 编译mysql数据库固件,并调用

前言 openwrt 编译源码mysql数据库,并编写demo调用 一、整体架构设计 作者要做一个项目,没有后端服务,只有一个电脑,需要在电脑上安装mysql服务端。然后在设备上安装mysql客户端。 二、PC安装mysql 1.官网链接 自行百度安装&a…

一文彻底搞懂spring循环依赖

文章目录 1. 什么是循环依赖2. Spring怎么解决循环依赖3. 无法处理的循环依赖 1. 什么是循环依赖 Spring 中的循环依赖是指两个或多个 Bean 之间相互依赖,形成一个循环引用的情况。在 Spring 容器中,循环依赖通常指的是单例(Singleton&#…

使用 Idea 快速搭建 SpringMVC 项目的详细步骤

一、开篇 SpringMVC 是一款当下流行的优秀的 MVC 框架,关于 MVC 的概念、作用、优点等内容介绍,在作者之前的一篇 Chat 《深入理解 MVC 框架原理:自定义 Struts2 框架》中有详细的描述。描述了关于另一款主流 MVC 框架的原理介绍,…

Docker-Container

Docker ①什么是容器②为什么需要容器③容器的生命周期容器 OOM容器异常退出容器暂停 ④容器命令清单总览docker createdocker rundocker psdocker logsdocker attachdocker execdocker startdocker stopdocker restartdocker killdocker topdocker statsdocker container insp…

unrealbuildtool 无法找到,执行 Generate Visual Studio Project 错误

参考链接 Generate cpp project Couldnt find UnrealBuildTool - Pipeline & Plugins / Plugins - Epic Developer Community Forums (unrealengine.com) 错误提示如下图: 解决方案: 打开 UnrealBuildTool,生成解决方案就可以了

学习Fast-LIO系列代码中相关概念理解

目录 一、流形和流形空间(姿态) 1.1 定义 1.2 为什么要有流形? 1.3 流形要满足什么性质? (1) 拓扑同胚 (2) 可微结构 1.4 欧式空间和流形空间的区别和联系? (1) 区别: (2) 联系: 1.5 将姿态定义在流形上比…

【Python实用标准库】argparser使用教程

argparser使用教程 1.介绍2.基本使用3.add_argument() 参数设置4.参考 1.介绍 (一)argparse 模块是 Python 内置的用于命令项选项与参数解析的模块,其用主要在两个方面: 一方面在python文件中可以将算法参数集中放到一起&#x…