CVPR最牛图像评价算法!

news2024/9/25 3:06:06
本文所涉及所有资源均在  传知代码平台可获取。

目录

概述

一、论文思路

1.多任务学习框架:

2.视觉-语言对应关系:

3.动态损失权重:

4.模型优化和评估:

二、模型介绍

三、详细实现方法

1.图像编码器和语言编码器(Image Encoder and Language Encoder)

2.特征嵌入(Feature Embedding)

3.余弦相似度计算(Cosine Similarity Calculation)

4.联合概率计算(Joint Probability Calculation)

5.边际化(Marginalization)

6.损失函数(Loss Functions)

7.最终损失(Final Loss)

四、复现过程

一.代码结构

二.使用方法

环境设置

训练模型

测试模型

1.准备测试数据:

2.加载预训练模型:

3.运行测试脚本:

自己的思考:

可以改进的地方:

演示效果

核心逻辑

概述

这篇论文提出了一种基于视觉-语言对应关系的盲图像质量评估方法,通过多任务学习利用其他两个辅助任务的知识,来预测没有参考信息的图像质量。设计了一个多任务学习方案,通过计算所有标签组合并计算视觉-文本嵌入的余弦相似性来得到联合概率,从而推断出每个任务的预测结果,并设计了数据损失函数进行优化。

在三个任务——盲图像质量评估、场景分类和失真类型识别的综合实验中,结果表明所提出的方法能够从场景分类和失真类型识别任务中受益,并在多个图像质量评估数据集上超越了现有技术水平。

一、论文思路

这篇论文提出了一种基于视觉-语言对应关系的盲图像质量评估方法(BIQA),通过多任务学习方案来提升BIQA的性能。主要思路可以总结如下:

1.多任务学习框架:

作者提出了一种通用的多任务学习方案,将BIQA、场景分类和失真类型识别三个任务联合起来进行训练。通过这种方式,模型可以从其他任务中获取辅助知识,以提高BIQA的性能。

2.视觉-语言对应关系:

作者利用预训练的对比学习视觉-语言模型(CLIP)来获取图像和文本的嵌入表示。通过计算图像嵌入与所有候选文本嵌入之间的余弦相似度,可以得到三个任务的联合概率分布。

3.动态损失权重:

在多任务学习中,作者采用了一种简单而高效的方法来自动确定每个任务的损失权重。这种动态权重分配有助于模型在训练过程中更好地平衡不同任务的重要性。

4.模型优化和评估:

作者在多个BIQA数据集上进行了实验,结果表明所提出的方法在预测准确性、泛化能力和质量注释调整方面都优于现有的BIQA技术。

二、模型介绍

1.任务定义:除了盲图像质量评价BIQA任务外,还定义了场景分类(scene classification)和失真类型识别(distortion type identification)两个辅助任务。

2.数据准备:为现有的IQA数据集补充场景分类和失真类型标签,以便在多任务学习框架下联合训练。

3.视觉-语言表示:使用预训练的对比学习视觉-语言模型(CLIP)来获取图像和文本的嵌入表示。图像通过视觉编码器处理,文本通过语言编码器处理。

4.多任务学习:通过计算图像嵌入与所有候选文本嵌入之间的余弦相似度,得到三个任务的联合概率分布。然后,通过边际化这个联合分布,得到每个任务的边际概率,并进一步将离散的质量等级转换为连续的质量分数。

5.损失函数设计:为BIQA、场景分类和失真类型识别设计了三种类型的损失函数,包括排序损失、二元损失和多类损失,并通过动态权重分配来自动优化这些损失函数。

6.模型优化:在多个IQA数据集上联合优化整个方法,最小化加权损失函数的总和。损失权重根据训练动态自动调整。

7.训练过程:使用AdamW优化器,在多个数据集上训练模型,采用动态调整的学习率和余弦退火策略。

三、详细实现方法

1.图像编码器和语言编码器(Image Encoder and Language Encoder)

2.特征嵌入(Feature Embedding)

3.余弦相似度计算(Cosine Similarity Calculation)

4.联合概率计算(Joint Probability Calculation)

5.边际化(Marginalization)

6.损失函数(Loss Functions)

7.最终损失(Final Loss)

四、复现过程

一.代码结构

1.data文件夹:

这是示例图像文件,供demo代码测试时使用。

2.IQA_Database:

这是一个数据集文件夹,包含了不同的图像质量评估(IQA)数据库,例如 BID, ChallengeDB_release, CSIQ, databaserelase2, kadid10k, koniq-10k。这些数据库用于训练和评估图像质量评估模型。

3.BIQA_benchmark.py:

这是一个benchmark测试脚本,用于在不同的IQA数据库上测试模型的性能。

4.clip_biqa.png:

这是CLIP模型的结构框图。

5.demo.py 和 demo2.py:

这两个文件是演示脚本,展示了如何使用LIQE算法进行图像质量评估.

6.ImageDataset.py 和 ImageDataset2.py:

这些文件定义了图像数据集类,用于加载和处理图像数据,供模型训练和评估使用.

7.LIQE.pt:

这是LIQE模型的预训练权重文件。代码会加载这个文件以使用预训练的模型进行图像质量评估。

8.LIQE.py:

这是主要的LIQE算法实现文件,包含了LIQE算法的核心逻辑。

9.MNL_Loss.py:

这是定义了多类对数损失函数的文件,用于训练图像质量评估模型。

10.OutputSaver.py:

这个文件包含保存模型输出结果的函数,可能用于保存预测结果或中间计算结果。

11.README.md:

这是项目的说明文件,通常包含项目的介绍、安装和使用说明。

12.train_unique_clip_weight.py:

这是用于训练模型的脚本,包含了训练流程的实现。

13.utils.py:

这是包含各种实用函数的文件,可能用于数据预处理、图像操作等。

14.weight_methods.py:

这个文件可能包含了一些与权重处理相关的方法或工具函数。二.使用方法

环境设置

1.安装必要的库:torch 2.1.0,python3

2.下载和解压数据集:下载IQA数据库,并解压到 IQA_Database 文件夹下。

3.修改数据集路径(train_unique_clip_weight.py):

训练模型

1.准备训练数据:

确保 IQA_Database 文件夹中包含了所有需要的训练数据集。

可以根据 ImageDataset.py 和 ImageDataset2.py 文件中的定义来加载和处理图像数据。

2.运行训练脚本:

使用 train_unique_clip_weight.py 进行模型训练。该脚本定义了训练流程,包括数据加载、模型训练、损失计算等步骤。

参数解释:

–data_path:数据集的路径。

–epochs:训练的轮数。

–batch_size:每个批次的图像数量。

–lr:学习率。测试模型

1.准备测试数据:

确保测试图像文件(如 data/6898804586.jpg 和 data/I02_01_03.png)存在于 data 文件夹中。

2.加载预训练模型:

将 LIQE.pt 放置在合适的目录中,并确保代码能够正确加载预训练模型。

3.运行测试脚本:

使用 demo.py 或 demo2.py 进行模型测试,评估图像的质量。

自己的思考:

本文算法取得了很好的效果,且发表在cvpr上,除了算法本身的效果确实很好,而且结合了现在很火的多模态模型CLIP,将CLIP用到了IQA领域,并且结合多任务学习,方法上很新颖;再一个,作者的工作量也很大,为现有的六个质量评价数据集添加了两种标签。

可以改进的地方:

1.退化空间的进一步扩展:

尽管现有的退化空间已经非常大,但可以进一步研究如何通过更多类型和更复杂的退化来扩展这一空间,以更好地模拟真实世界中的复杂情况。2.模型架构优化:

当前的方法主要基于ResNet-50等常见架构,可以尝试使用更复杂或更适合BIQA任务的架构,如更深的神经网络或专门设计的模型,以进一步提高性能。3.对比学习中的噪声处理:

在对比学习过程中,可能存在一些噪声样本(如不同内容但相似质量的样本)。可以研究更有效的噪声处理方法,以进一步提升模型的鲁棒性。

演示效果

训练过程演示:

首先加载csv文件:

开始训练:

demo测试运行结果:

结果说明:

Image1经过LIQE算法后的质量评价结果:图像 #1 是一张曝光不足伪影的人体照片,其感知质量为 1.2373046875,由 LIQE 量化

Image2经过LIQE算法后的质量评价结果:图像 #2 是一张带有模糊伪像的风景照片,其感知质量为 2.8671875,由 LIQE 量化

核心逻辑

LIQE算法的核心逻辑:

class LIQE(nn.Module):
    def __init__(self, ckpt, device):
        super(LIQE, self).__init__()
        self.model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
        checkpoint = torch.load(ckpt, map_location=device)
        self.model.load_state_dict(checkpoint)
        joint_texts = torch.cat(
            [clip.tokenize(f"a photo of a {c} with {d} artifacts, which is of {q} quality") for q, c, d
             in product(qualitys, scenes, dists_map)]).to(device)
        with torch.no_grad():
            self.text_features = self.model.encode_text(joint_texts)
            self.text_features = self.text_features / self.text_features.norm(dim=1, keepdim=True)
        self.step = 32
        self.num_patch = 15
        self.normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.device = device

    def forward(self, x):
        x = x.to(self.device)
        batch_size = x.size(0)
        x = self.normalize(x)
        x = x.unfold(2, 224, self.step).unfold(3, 224, self.step).permute(2, 3, 0, 1, 4, 5).reshape(-1, 3, 224, 224)

        sel_step = x.size(0) // self.num_patch
        sel = torch.zeros(self.num_patch)
        for i in range(self.num_patch):
            sel[i] = sel_step * i
        sel = sel.long()
        x = x[sel, ...]

        image_features = self.model.encode_image(x)

        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ self.text_features.t()

        logits_per_image = logits_per_image.view(batch_size, self.num_patch, -1)
        logits_per_image = logits_per_image.mean(1)
        logits_per_image = F.softmax(logits_per_image, dim=1)

        logits_per_image = logits_per_image.view(-1, len(qualitys), len(scenes), len(dists_map))
        logits_quality = logits_per_image.sum(3).sum(2)

        similarity_scene = logits_per_image.sum(3).sum(1)
        similarity_distortion = logits_per_image.sum(1).sum(1)
        distortion_index = similarity_distortion.argmax(dim=1)
        scene_index = similarity_scene.argmax(dim=1)

        scene = scenes[scene_index]
        distortion = dists_map[distortion_index]

        quality = 1 * logits_quality[:, 0] + 2 * logits_quality[:, 1] + 3 * logits_quality[:, 2] + \
                             4 * logits_quality[:, 3] + 5 * logits_quality[:, 4]

        return quality, scene, distortion

if __name__ == '__main__':
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    ckpt = './LIQE.pt'
    liqe = LIQE(ckpt, device)

    x = torch.randn(1,3,512,512).to(device)
    q, s, d = liqe(x)

感觉不错,点击我,立即使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2162229.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

德蒂企鹅PAEDIPROTECT:德国医研力作,专为敏感肌婴幼儿量身打造

新生儿的诞生总是伴随着喜悦,也充满着手忙脚乱,尤其是敏感肌宝宝的皮肤护理。宝宝的皮肤如同初绽的花瓣,皮肤角质层薄而脆弱,容易受到外界刺激物的影响,水分流失快,经常会出现干燥、瘙痒、红斑甚至湿疹等症…

【ARM】AMBA和总线

AMBA AMBA(Advanced Microcontroller Bus Architecture) 总线是由ARM公司提出的一种开放性的片上总线标准,它独立于处理器和工艺技术,具有高速度低功耗等特点。 总线:系统芯片中各个模块之间需要有接口来连接。总线作…

爬虫类Chrome去除前端无限debugger反调试(轻松分析算法)

文章目录 引言方法1(简易抓包或者分析js适用)方法2(解决实际问题-最简单的方法)方法3(解决实际问题-麻烦点也是学会fiddler的一个功能)第一步:熟悉界面的大致功能意思第二步:保存出需要替换的代码,记住保存位置,待会儿要用第三步&…

【Python篇】详细学习 pandas 和 xlrd:从零开始

文章目录 详细学习 pandas 和 xlrd:从零开始前言一、环境准备和安装1.1 安装 pandas 和 xlrd1.2 验证安装 二、pandas 和 xlrd 的基础概念2.1 什么是 pandas?2.2 什么是 xlrd? 三、使用 pandas 读取 Excel 文件3.1 读取 Excel 文件的基础方法…

如何在精益六西格玛项目实践中激励小组成员保持积极性?

在精益六西格玛项目实践中,激励小组成员保持积极性是推动项目成功与持续改进的关键因素。精益六西格玛作为一种集精益生产与六西格玛管理精髓于一体的管理模式,旨在通过流程优化、质量提升及成本降低,实现企业的卓越绩效。然而,这…

《DevOps实践指南》笔记-Part 3

一篇文章显得略长,本文对应第5-6章、附录、认证考试、参考资源等。 前言、第1-2章请参考Part 1,第3-4章内容,请参考Part 2。 持续学习与实验的技术实践 通过以下方式制定有关提高安全性、持续改进和边做边学的制度: 建立公正的…

找不到MFC140.dll无法继续执行代码怎么办,共有6种解决方法

在计算机使用过程中,我们可能会遇到各种问题,其中一种常见的问题是DLL文件丢失。DLL文件是动态链接库文件,它包含了可以被多个程序共享的代码和数据。MFC140.dll就是其中之一。本文将深入分析MFC140.dll丢失的原因,并提供6种有效的…

双亲委派机制SPI

SPI如何破坏双亲委派机制?可根据以下概念一步步深入 什么是双亲委派机制? 双亲委派机制是Java类加载器体系中采用的一种类加载策略,旨在保证类加载的安全性和稳定性。 这一机制规定了类加载的顺序和规则,即当一个类加载器收到类…

解决启动docker desktop报The network name cannot be found的问题

现象 deploying WSL2 distributions ensuring main distro is deployed: checking if main distro is up to date: checking main distro bootstrap version: getting main distro bootstrap version: open \wsl$\docker-desktop\etc\wsl_bootstrap_version: The network name…

基于Springboot+vue实现的Cosplay论坛系统

基于springbootvue实现的Cosplay论坛系统 (源码L文ppt)4-066 2.3 系统功能分析 Cosplay论坛系统中采用了Java的springboot框架进行开发,在数据库上选择MYSQL,在功能上Cosplay论坛系统我划分为了普通用户管理模…

Proteus如何添加数码管

1、打开安装好的Proteus,点击上方菜单栏中的“库”,再选择“从库选取零件”,或者在左侧元件列表中单击鼠标右键,再点击右键菜单中的“从库中挑选”选项。 2、之后在元器件库中,点击类别中的“Optoelectronics”&#…

破解 oklink 网站加密数据(升级版)

大家好!我是炒青椒不放辣,关注我,收看每期的编程干货。 逆向是爬虫工程师进阶必备技能,当我们遇到一个问题时可能会有多种解决途径,而如何做出最高效的抉择又需要经验的积累。本期文章将以实战的方式,带你详细地分析并破解 oklink 网站加密数据 特别声明:本篇文章仅供学…

python脚本程序怎么写更优雅?argparse模块巧妙应用

前言 命令行程序,也称CLI程序,另一个直观的名字是脚本程序,简称脚本,由于没有图形用户界面(GUI),所以脚本程序常见的交互方式有3种: 1、脚本程序中读取环境变量,比如env…

解决小爱音箱连接Windows10蓝牙时,语音控制会中断音乐播放的问题

解决天猫精灵连接Windows10蓝牙时,语音控制会中断音乐播放的问题 解决小爱音箱连接Windows10蓝牙时,浏览器控制音量会中断音乐播放的问题 用小爱音箱当蓝牙音响的时候,遇到个很困扰的问题,每次小爱音箱语音控制的过程中,都会启动…

3.js - 运动曲线

这个球,绕着这个红色的线圈转 代码 import * as THREE from three import { OrbitControls } from three/examples/jsm/controls/OrbitControlslet scene,camera,renderer,controls nulllet moon,earth null// 根据,一系列的点,创建曲线 le…

活动报名丨智源Workshop,从o1出发探索LLM推理与思维链

近期o1模型的发布,预示着AI在处理高度复杂问题上再次迈出一大步。大规模强化学习算法在一个数据极高的训练过程中,教会了模型如何利用其思维链进行富有成效的思考。 北京时间9月19日(本周四)晚7点,智源社区将组织「智源…

响应式布局-媒体查询父级布局容器

1.响应式布局容器 父局作为布局容器,配合自己元素实现变化效果,原理:在不通过屏幕下面吗,通过媒体查询来改变子元素的排列方式和大小,从而实现不同尺寸屏幕下看到不同的效果。 2.响应尺寸布局容器常见宽度划分 手机-…

Vue 中 watch 的使用方法及注意事项

前言 Vue 的 Watch 是一个非常有用的功能,它能够监听 Vue 实例数据的变化并执行相应的操作。本篇文章将详细介绍 Vue Watch 的使用方法和注意事项,让你能够充分利用 Watch 来解决 Vue 开发中的各种问题。 1. Watch 是什么? 1.1 Watch 的作…

[js逆向学习] fastmoss电商网站——店铺排名

逆向目标 网站:https://www.fastmoss.com/shop-marketing/tiktok接口:https://www.fastmoss.com/api/shop/shopList/参数:fm-sign 逆向分析 我们今天要分析的是店铺排名,先分析网络请求,找到目标接口 按照上图操作…

Redis: 特点,优势,与其他产品的区别以及高并发原理

入门Redis概述 1 )选择Redis是因为其高性能 因为 Redis 它数据存储的机制是存在内存中的,减少了传统关系数据库的磁盘IO它是单线程的保证了原子性,它还提供了事务,锁等相关的机制 2 )Redis 环境安装配置 linux 或 d…