【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS

news2025/1/18 6:14:00

文件路径daclip-uir-main/universal-image-restoration/config/daclip-sde/test.py

代码有部分修改

导包

import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutils

import numpy as np
import torch
from IPython import embed
import lpips

import options as option
from models import create_model

sys.path.insert(0, "../../")
import open_clip
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

注意open_clip使用的是项目里的代码,而非环境里装的那个。data、util、option同样是项目里有的包

声明

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, default='options/test.yml', help="Path to options YMAL file.")
opt = option.parse(parser.parse_args().opt, is_train=False)

opt = option.dict_to_nonedict(opt)

配置文件 

设置配置文件相对地址options/test.yml

在该配置文件中配置GT和LQ图像文件地址

datasets:
  test1:
   name: Test
   mode: LQGT
   dataroot_GT: C:\Users\86136\Desktop\LQ_test\shadow\GT
   dataroot_LQ: C:\Users\86136\Desktop\LQ_test\shadow\LQ

设置results_root结果地址,每次计算结束这个地址保存要求记录的计算结果

该目录下Test文件夹将保存一张GT一张LQ一张复原图像  。

不设置也会默认在项目内 daclip-uir-main\results\daclip-sde\universal-ir

#### path
path:
  pretrain_model_G: E:\daclip\pretrained\universal-ir.pth
  daclip: E:\daclip\pretrained\daclip_ViT-B-32.pt
  results_root: C:\Users\86136\Desktop\daclip-uir-main\results\daclip-sde\universal-ir
  log: 

 

#### mkdir and logger
util.mkdirs(
    (
        path
        for key, path in opt["path"].items()
        if not key == "experiments_root"
        and "pretrain_model" not in key
        and "resume" not in key
    )
)

# os.system("rm ./result")
# os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

 报错执行代码没有删除再创建权限?我把相关os操作注释了,全部保存到result对我影响不大

加载创建数据对

#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info(
        "Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)
        )
    )
    test_loaders.append(test_loader)

 自定义包含复原IR-SDE模型的外层类model,参考app.py

# load pretrained model by default
model = create_model(opt)
device = model.device

 加载DA-CLIP、IR-SDE

# clip_model, _preprocess = clip.load("ViT-B/32", device=device)
if opt['path']['daclip'] is not None:
    clip_model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=opt['path']['daclip'])
else:
    clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
clip_model = clip_model.to(device)

else是直接使用CLIP的ViT-B-32模型进行测试的代码。与我测DA-CLIP无关。

想使用的话 目测要预先下载对应模型权重并手动修改pretrained为文件地址,否则报错hf无法连接

sde = util.IRSDE(max_sigma=opt["sde"]["max_sigma"], T=opt["sde"]["T"], schedule=opt["sde"]["schedule"], eps=opt["sde"]["eps"], device=device)
sde.set_model(model.model)
lpips_fn = lpips.LPIPS(net='alex').to(device)

scale = opt['degradation']['scale']

加载IR-SDE、LPIPS

如果不指定crop_border后续crop_border=scale

处理并计算


for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt["name"]  # path opt['']
    logger.info("\nTesting [{:s}]...".format(test_set_name))
    test_start_time = time.time()
    dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
    util.mkdir(dataset_dir)

    test_results = OrderedDict()
    test_results["psnr"] = []
    test_results["ssim"] = []
    test_results["psnr_y"] = []
    test_results["ssim_y"] = []
    test_results["lpips"] = []
    test_times = []

    for i, test_data in enumerate(test_loader):
        single_img_psnr = []
        single_img_ssim = []
        single_img_psnr_y = []
        single_img_ssim_y = []
        need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else True
        img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        #### input dataset_LQ
        LQ, GT = test_data["LQ"], test_data["GT"]
        img4clip = test_data["LQ_clip"].to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_context, degra_context = clip_model.encode_image(img4clip, control=True)
            image_context = image_context.float()
            degra_context = degra_context.float()

        noisy_state = sde.noise_state(LQ)

        model.feed_data(noisy_state, LQ, GT, text_context=degra_context, image_context=image_context)
        tic = time.time()
        model.test(sde, save_states=False)
        toc = time.time()
        test_times.append(toc - tic)

        visuals = model.get_current_visuals()
        SR_img = visuals["Output"]
        output = util.tensor2img(SR_img.squeeze())  # uint8
        LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8
        GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8
        
        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(output, save_img_path)

        # remove it if you only want to save output images
        LQ_img_path = os.path.join(dataset_dir, img_name + "_LQ.png")
        GT_img_path = os.path.join(dataset_dir, img_name + "_HQ.png")
        util.save_img(LQ_, LQ_img_path)
        util.save_img(GT_, GT_img_path)

        if need_GT:
            gt_img = GT_ / 255.0
            sr_img = output / 255.0

            crop_border = opt["crop_border"] if opt["crop_border"] else scale
            if crop_border == 0:
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            else:
                cropped_sr_img = sr_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]

            psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            lp_score = lpips_fn(
                GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1).squeeze().item()

            test_results["psnr"].append(psnr)
            test_results["ssim"].append(ssim)
            test_results["lpips"].append(lp_score)

            if len(gt_img.shape) == 3:
                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    if crop_border == 0:
                        cropped_sr_img_y = sr_img_y
                        cropped_gt_img_y = gt_img_y
                    else:
                        cropped_sr_img_y = sr_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                        cropped_gt_img_y = gt_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                    psnr_y = util.calculate_psnr(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )
                    ssim_y = util.calculate_ssim(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )

                    test_results["psnr_y"].append(psnr_y)
                    test_results["ssim_y"].append(ssim_y)

                    logger.info(
                        "img{:3d}:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}; LPIPS: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.".format(
                            i, img_name, psnr, ssim, lp_score, psnr_y, ssim_y
                        )
                    )
            else:
                logger.info(
                    "img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                        img_name, psnr, ssim
                    )
                )

                test_results["psnr_y"].append(psnr)
                test_results["ssim_y"].append(ssim)
        else:
            logger.info(img_name)


    ave_lpips = sum(test_results["lpips"]) / len(test_results["lpips"])
    ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
    ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
    logger.info(
        "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(
            test_set_name, ave_psnr, ave_ssim
        )
    )
    if test_results["psnr_y"] and test_results["ssim_y"]:
        ave_psnr_y = sum(test_results["psnr_y"]) / len(test_results["psnr_y"])
        ave_ssim_y = sum(test_results["ssim_y"]) / len(test_results["ssim_y"])
        logger.info(
            "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(
                ave_psnr_y, ave_ssim_y
            )
        )

    logger.info(
            "----average LPIPS\t: {:.6f}\n".format(ave_lpips)
        )

    print(f"average test time: {np.mean(test_times):.4f}")

开头往log记录了相应配置文件内容,不需要可以注释。

遍历测试数据集(test_loaders)计算各种评价指标,如峰值信噪比(PSNR)、结构相似性(SSIM)和感知损失(LPIPS)。

在处理过程中,代码首先会创建一个目录来保存测试结果。

然后,对于每个测试图像,代码会加载对应的图像(如果可用),并使用一个名为clip_model的模型对图像进行编码。

接下来,代码会使用一个名为sde的随机微分方程模型和名为model的深度学习模型来处理带有噪声的图像,并生成复原图像(SR_img)。额可能作者拿了以前做超分的代码没改变量名

在这个过程中,text_contextimage_context被用作模型的输入,

图像都会被保存到之前创建的目录中。

此外,代码还会计算并记录每个图像的PSNR、SSIM和LPIPS分数,并在最后打印出这些分数的平均值。 代码中还包含了一些用于图像处理的实用函数,如util.tensor2img用于将张量转换为图像,util.save_img用于保存图像,以及util.calculate_psnrutil.calculate_ssim用于计算PSNR和SSIM分数。psnr_y和ssim_y 不用可以把相关代码注释。

最后,代码还计算了平均测试时间,并将其打印出来。

结果

log处理的单张图像报错的信息 0是该处理的图像排序序号,即正在处理第0张图

24-04-03 17:28:24.697 - INFO: img  0:_MG_2374_no_shadow - PSNR: 27.779773 dB; SSIM: 0.863140; LPIPS: 0.078669; PSNR_Y: 29.135256 dB; SSIM_Y: 0.869278.

 

可以给复原结果图加个后缀方便区分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1569185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

clickhouse 源码编译部署

clickhouse 源码编译部署 版本 21.7.9.7 点击build project,编译工程,经过一定时间(第一次编译可能几个小时,后续再编译,只编译有改动的文件)生成release目录 在cmake-build-release → programs目录下…

CATIA软件 焊点坐标(BiW Welding SpotPoint)导出txt文本的操作方法

通常我们客户给的工件是带焊点球的形式,且可导出焊点坐标。如下图所示的焊点位置标识及坐标,即是CATIA中Automotive BiW Fastening模块下的BiW Welding SpotPoint焊点定义。 遇到这样的形式,要导出焊点坐标txt文本,可按如下图片找…

C语言—用EasyX实现反弹球消砖块游戏

代码效果如下 #undef UNICODE #undef _UNICODE #include<graphics.h> #include<conio.h> #include<time.h> #include<stdio.h>#define width 640 #define high 480 #define brick_num 10int ball_x, ball_y; int ball_vx, ball_vy; int radius; int ba…

Chatgpt掘金之旅—有爱AI商业实战篇|语言翻译|(五)

演示站点&#xff1a; https://ai.uaai.cn 对话模块 官方论坛&#xff1a; www.jingyuai.com 京娱AI 引言 人工智能&#xff08;AI&#xff09;技术作为当今科技创新的前沿领域&#xff0c;为创业者提供了广阔的机会和挑战。随着AI技术的快速发展和应用领域的不断拓展&#xf…

《QT实用小工具·十二》邮件批量发送工具

1、概述 源码放在文章末尾 该项目实现了邮件的批量发送&#xff0c;如下图所示&#xff1a; 项目部分代码如下所示&#xff1a; #ifndef SMTPCLIENT_H #define SMTPCLIENT_H#include <QtGui> #include <QtNetwork> #if (QT_VERSION > QT_VERSION_CHECK(5,0,…

AcWing-游戏

1388. 游戏 - AcWing题库 所需知识&#xff1a;博弈论&#xff0c;区间dp 由于双方都采取最优的策略来取数字&#xff0c;所以结果为确定的&#xff0c;有可能会有多个不同的过程&#xff0c;但是我们只需要关注最终结果就行了。 方法一&#xff1a; 定义dp[i][j] 表示区间…

从“量子”到分子:探索计算的无限可能 | 综述荐读

在2023年年末&#xff0c;两篇划时代的研究报告在《科学》&#xff08;Science&#xff09;杂志上引发了广泛关注。这两篇论文分别来自两个研究小组&#xff0c;它们共同揭示了单氟化钙分子间相互作用的研究成果&#xff0c;成功地在这些分子间创造出了分子量子比特。这一成就不…

AI大模型与网球运动结合的应用场景及案例分析

AI大模型与网球运动结合的未来前景是广阔的&#xff0c;它不仅能够提升运动员的训练和比赛表现&#xff0c;还能改善教练的策略制定、增强观众的观赛体验以及优化网球赛事的管理。以下是几个具体的应用场景&#xff1a; 1. 运动员技能和表现分析 AI大模型可以通过分析高速摄像…

ROS 2边学边练(12)-- 创建一个工作空间

上一篇我们已经接触过工作空间的概念&#xff0c;并简单了解体验了一点构建包、测试包的流程&#xff0c;此篇会深入一点学习工作空间相关内容。 前言 一个工作空间是包含了ROS 2的功能包的目录&#xff08;文件夹&#xff09;&#xff0c;在使用ROS 2之前我们得激活一下目标工…

希尔排序和快排里的小区间优化

希尔排序 希尔排序是插入排序的优化。 当一串数是逆序时&#xff0c;那么每插入一个数&#xff0c;前面的数都会向后面挪动。 那么这是插入排序的时间复杂度&#xff0c;就会达到O(n^2) 希尔排序是对数组里的数进行预排序。 防止插入排序出现最坏的情况。 预排序&#xf…

2014最新AI学法减分交管12123小程序源码最新玩法

2014最新AI学法减分交管12123小程序源码最新玩法利用ChatGPT实现拍照搜题 利用ChatGPT实现拍照搜题 学法减分这个项目是几年之前的项目&#xff0c;老朋友都知道&#xff0c;以前我用Python实现了向量检索&#xff0c;也就是当时和大家说到的AI题库&#xff0c;那时候国内还没…

[lesson05]引用的本质分析

引用的本质分析 引用的意义 引用作为变量别名而存在&#xff0c;因此在一些场合可以代替指针 引用相对于指针来说具有更好的可读性和实用性 注意&#xff1a; 函数中的引用形参不需要进行初始化&#xff01;&#xff01; 特殊的引用 const引用 在C中可以声明const引用 cons…

957: 逆置单链表

学习版 【C语言】 #include<iostream> using namespace std; typedef struct LNode {char data;struct LNode* next;LNode(char x) :data(x), next(nullptr) {} }LNode; void creatlist(LNode *&L) {int n;char e;cin >> n;LNode* p1, * p2;p1 L;for (int i…

基于SpringBoot和Vue的教务网络管理系统的设计与实现【附源码】

1、系统演示视频&#xff08;演示视频&#xff09; 2、需要交流和学习请联系

近期全球AI重要资讯

文章目录 1. “免登录”挤爆ChatGPT&#xff0c;百度文心一言们会跟进吗&#xff1f;免登录的便利性行业跟进的可能性对行业的深远影响 2. 开源11天&#xff0c;马斯克再发Grok-1.5&#xff01;128K代码击败GPT-4长语境理解和高级推理代码生成和解决问题能力对开源AI生态的影响…

ES6: class类

类 class 面相对象class关键字创建类关于类的继承 面相对象 一切皆对象。 举例&#xff1a; 操作浏览器要使用window对象&#xff1b;操作网页要使用document对象&#xff1b;操作控制台要使用console对象&#xff1b; ES6中增加了类的概念&#xff0c;其实ES5中已经可以实现类…

43.1k star, 免费开源的 markdown 编辑器 MarkText

43.1k star, 免费开源的 markdown 编辑器 MarkText 分类 开源分享 项目名: MarkText -- 简单而优雅的开源 Markdown 编辑器 Github 开源地址&#xff1a; https://github.com/marktext/marktext 官网地址&#xff1a; MarkText 支持平台&#xff1a; Linux, macOS 以及 Win…

备战蓝桥杯---DP刷题2

1.树形DP&#xff1a; 即问那几个点在树的直径上&#xff0c;类似ROAD那题&#xff0c;我们先求一下每一个子树根的子树的最大值与次大值用d1,d2表示&#xff0c;直径就是d1d2的最大值&#xff0c;那么我们如何判断是否在最大路径上&#xff0c;其实就是看一下从某一点出发的所…

【Qt】:常用控件(四:显示类控件)

常用控件 一.Lable二.LCD Number 一.Lable QLabel 可以⽤来显⽰⽂本和图⽚. 代码⽰例:显⽰不同格式的⽂本 代码⽰例:显⽰图⽚ 此时,如果拖动窗⼝⼤⼩,可以看到图⽚并不会随着窗⼝⼤⼩的改变⽽同步变化 为了解决这个问题,可以在Widget中重写resizeEvent函数。当用户把窗口从A拖…

五、企业级架构之Nginx负载均衡

一、负载均衡技术 1、介绍&#xff1a; 负载均衡技术&#xff08;Load Balance&#xff09;是一种概念&#xff0c;其原理就是把分发流量、请求到不同的服务器&#xff0c;平均分配用户请求。 2、作用&#xff1a; ① 流量分发&#xff0c;请求平均&#xff0c;提高系统处理…