根据DCT特征训练CNN

news2024/11/20 15:23:58

记录一次改代码的挣扎经历:
        看了几篇关于DCT频域的深度模型文献,尤其是21年FcaNet:基于DCT 的attention model,咱就是说想试试将我模型的输入改为分组的DCT系数,然后就开始下面的波折了。

第一次尝试:

        我直接调用了库函数,然后出现问题了:这个库函数是应用在numpy数组上,得在CPU上处理。

from scipy.fftpack import dct, idct
...
dct_block = dct(dct(block, axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
...
block = idct(idct(dct_block, axis=2, norm='ortho'), axis=3, norm='ortho')    # [B,C,k,k]

第二次尝试:
        好吧,我先把数据调回CPU,处理后,再调回GPU,又有新问题了:这样做(将block从GPU转移至CPU)torch类型张量转换为numpy数组时,torch张量的梯度无法保存。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]

...

# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

 第三次尝试:

        根据报错提醒,我进行以下改进,将block_cpu.numpy -> block_cpu.detach.numpy(),即忽略掉torch类型张量带着的梯度信息,哈哈,这样一改,梯度就丢失了,模型就不能反向传播进行更新训练了。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]

...

# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.detach.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

第四次尝试:
        CPU上库函数不好用,那我自己写(借鉴)DCT变换的函数嘛,DCT就是输入k*k图像关于k*k个余弦基函数的加权和嘛:

 别人写的的8 x 8d的DCT和IDCT的实现:


class DCT8X8(nn.Module):
    """ Discrete Cosine Transformation
    Input:
        image(tensor): batch x height x width
    Output:
        dcp(tensor): batch x height x width
    """

    def __init__(self):
        super(DCT8X8, self).__init__()
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)

        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)

        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)

        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        result.view(image.shape)
        return result



class IDCT8X8(nn.Module):
    """ Inverse discrete Cosine Transformation
    Input:
        dcp(tensor): batch x height x width
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self):
        super(IDCT8X8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result.view(image.shape)
        return result

我根据上述改的任意block_size的DCT和IDCT:

class DCTCustom(nn.Module):
    """Customizable Discrete Cosine Transformation
    Input:
        image(tensor): batch x height x width
    Output:
        dct(tensor): batch x height x width
    """

    def __init__(self, input_size=8):
        super(DCTCustom, self).__init__()
        self.input_size = input_size
        tensor = np.zeros((input_size, input_size, input_size, input_size), dtype=np.float32)

        for x, y, u, v in itertools.product(range(input_size), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / (2 * input_size)) * np.cos((2 * y + 1) * v * np.pi / (2 * input_size))

        alpha = np.array([1. / np.sqrt(2)] + [1] * (input_size - 1))

        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        result = result.view(image.shape)  # Corrected line
        return result



class IDCTCustom(nn.Module):
    """ Inverse discrete Cosine Transformation
    Input:
        dcp(tensor): batch x height x width
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self, block_size=8):
        super(IDCTCustom, self).__init__()
        self.block_size = block_size

        # Compute alpha coefficients
        alpha = np.array([1. / np.sqrt(2)] + [1] * (block_size - 1))
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())

        # Compute tensor for IDCT
        tensor = np.zeros((block_size, block_size, block_size, block_size), dtype=np.float32)
        for x, y, u, v in itertools.product(range(block_size), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / (2 * block_size)) * np.cos(
                (2 * v + 1) * y * np.pi / (2 * block_size)
            )
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        if image.shape[-2] % self.block_size != 0 or image.shape[-1] % self.block_size != 0:
            raise ValueError("Input dimensions must be divisible by the block size.")

        # Apply IDCT
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result = result.view(image.shape)
        return result

        不出意外的话,问题又出现了,我对一个torch.ones((2,3,k,k))的张量进行DCT,再IDCT恢复。当k=8时(即block_size=8x8)时,能够完全恢复,但当k!=8(=16、32)时,经IDCT后无法恢复原始输入,懵。

第五次尝试(hh):
        突然!我发现了torch内置的DCT函数!可以再GPU上实现DCT。

torch-dct · PyPI

import torch_dct as dct

# 图像分块    # [B,C,H,W]
    ...        # [B,C,k,k]
    # dct
    block = dct.dct_2d(block)     # [B,C,k,k]

    ...
    # idct
    block = dct.idct_2d(block)        # [B,C,k,k]

 然后又有问题了:
        我的模型开始训练后,我发现我的每个epoch的loss都为NAN...

        然后我打印了DCT输出,发现DCT系数长这个样子,CNN不高兴好好训练吧。

        我们再想想办法将输入数据归一化到范围[0, 1]或[-1, 1]之间,再喂给CNN吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1337695.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringCloud】-OpenFeign实战及源码解析、与Ribbon结合

一、背景介绍 二、正文 OpenFeign是什么? OpenFeign(简称Feign)是一个声明式的Web服务客户端,用于简化服务之间的HTTP通信。与Nacos和Ribbon等组件协同,以支持在微服务体系结构中方便地进行服务间的通信&#xff1b…

互联科技:全域托管云赋能百行百业的数字化转型

在这个数字经济时代,云计算技术为企业提供了更加高效的业务管理机会,百行百业加速上云。对比几种云网方案,目前公有云方案存在可控性低、数据暴露风险、个性化需求难以满足、服务受限等问题;私有云方案存在建设成本高、建设周期长…

TCP服务器的演变过程:IO多路复用机制select实现TCP服务器

IO多路复用机制select实现TCP服务器 一、前言二、新增使用API函数2.1、select()函数2.2、FD_*系列函数 三、实现步骤四、完整代码五、TCP客户端5.1、自己实现一个TCP客户端5.2、Windows下可以使用NetAssist的网络助手工具 小结 一、前言 手把手教你从0开始编写TCP服务器程序&a…

文献研读|Prompt窃取与保护综述

本文介绍与「Prompt窃取与保护」相关的几篇工作。 目录 1. Prompt Stealing Attacks Against Text-to-Image Generation Models(PromptStealer)2. Hard Prompts Made Easy: Gradient-Based Discrete Optimization for Prompt Tuning and Discovery&#…

Linux - 记录问题:怎么通过安装包的方式安装gRPC

适用场景 当docker 构建环境不能链接到github 的时候,就可以使用本地构建的方式 完成对应服务的构建需求。 参考案例 使用本地安装包的方式安装 gRPC 注意: 在Docker构建过程中,某些软件包可能会尝试配置时区,这通常需要交互式…

性能优化,让用户体验更加完美(渲染层面)

前言 上一篇我们已经围绕“网络层面”探索页面性能优化的方案,接下来本篇围绕“浏览器渲染层面”继续开展探索。正文开始前,我们思考如下问题: 浏览器渲染页面会经过哪几个关键环节?“渲染层面”的优化从哪几方面着手&#xff1f…

智能三维数据虚拟现实电子沙盘

一、概述 易图讯科技(www.3dgis.top)以大数据、云计算、虚拟现实、物联网、AI等先进技术为支撑,支持高清卫星影像、DEM高程数据、矢量数据、无人机倾斜摄像、BIM模型、点云、城市白模、等高线、标高点等数据融合和切换,智能三维数…

Git基础学习_p1

文章目录 一、前言二、Git手册学习2.1 Git介绍&前置知识2.2 Git教程2.2.1 导入新项目2.2.2 做更改2.2.3 Git追踪内容而非文件2.2.4 查看项目历史2.2.5 管理分支🔺2.2.6 用Git来协同工作2.2.7 查看历史 三、结尾 一、前言 Git相信大部分从事软件工作的人都听说过…

SadTalker数字人增加视频输出mp4质量精度

最近在用数字人简易方案,看到了sadtalker虽然效果差,但是可以作为一个快速方案,没有安装sd的版本,随便找了个一键安装包 设置如上 使用倒是非常简单,但是出现一个问题,就是输出的mp4都出马赛克了 界面上却…

001、安装 Rust

目录 1. 安装 Rust 2. 安装编译器 Visual Studio Code 3. 更新、卸载、文档命令 4. 结语 1. 安装 Rust 安装 Rust 非常简单,首先进入 Rust官网 ,然后点击右上角的 Install 。 进入 Install 界面, 它会自动识别你当前的操作系统并给你推荐…

自带AI算法的热红外相机

Tofu AIIR 是识别跟踪与热红外成像一体化的模组,支持热红外视频下的多类型物体检测、识别、跟踪等功能。 产品支持视频编码、设备管理、目标检测、深度学习识别、跟踪等功能,提供多机版与触控版管理软件,为二次开发提供了丰富的SDK接口和开源…

Xshell——Windows将本地文件上传到Linux服务器

1、scp命令 scp是基于ssh的网络文件传输命令,可以将本地文件或文件夹直接上传到服务器指定位置。命令格式: 上传文件 scp -P port filepath usernameip:TargetPath 上传文件夹 scp -r -P port filepath usernameip:TargetPath -P port:用于指…

Spark的生态系统概览:Spark SQL、Spark Streaming

Apache Spark是一个强大的分布式计算框架,用于大规模数据处理。Spark的生态系统包括多个组件,其中两个重要的组件是Spark SQL和Spark Streaming。本文将深入探讨这两个组件,了解它们的功能、用途以及如何在Spark生态系统中使用它们。 Spark …

Redis连接报错-Could not connect to Redis at 127.0.0.1:6379: Connection refused

进入Redis所在路径,命令行输入redis-cli报错:Could not connect to Redis at 127.0.0.1:6379: Connection refused 解决方法: redis-server redis.conf 连接成功:

装饰模式(单一责任)

Decorator(装饰模式:单一责任模式) 链接:装饰模式实例代码 解析 目的 在某些情况下我们可能会“过度地使用继承来扩展对象的功能”,由于继承为类型引入的静态特质,使得这种扩展方式缺乏灵活性&#xff…

H5调用企业微信扫一扫接口

一、依赖引入 <script src"http://res.wx.qq.com/open/js/jweixin-1.2.0.js"></script><!-- <script src"https://res.wx.qq.com/wwopen/js/jsapi/jweixin-1.0.0.js"></script> --><script src"https://open.work.…

ASP.NET MVC的5种AuthorizationFilter

一、IAuthorizationFilter 所有的AuthorizationFilter实现了接口IAuthorizationFilter。如下面的代码片断所示&#xff0c;IAuthorizationFilter定义了一个OnAuthorization方法用于实现授权的操作。作为该方法的参数filterContext是一个表示授权上下文的AuthorizationContext对…

探索前端开发趋势:2023年的新兴技术与发展方向

随着科技的不断发展&#xff0c;前端开发领域也在不断演进。本文将详细介绍2023年前端开发的新兴技术和发展趋势&#xff0c;为开发者们指明前端技术的发展方向和面临的挑战。从WebAssembly、PWA到低代码开发&#xff0c;激动人心的全新前景等你探索。 随着科技的快速发展&…

华锐视点为广汽集团打造VR汽车在线展厅,打破地域限制,尽享购车乐趣

随着科技的飞速发展&#xff0c;我们正在进入一个全新的时代——元宇宙时代。元宇宙是一个虚拟的世界&#xff0c;它不仅能够模拟现实世界&#xff0c;还能够创造出现实世界无法实现的事物。而汽车行业作为人类生活的重要组成部分&#xff0c;也在积极探索与元宇宙的融合&#…

SpringBoot3 整合Kafka

官网&#xff1a;https://kafka.apache.org/documentation/ 消息队列-场景 1. 异步 2. 解耦 3. 削峰 4. 缓冲 消息队列-Kafka 1. 消息模式 消息发布订阅模式&#xff0c;MessageQueue中的消息不删除&#xff0c;会记录消费者的偏移量 2. Kafka工作原理 同一个消费者组里的消…