DIF-Gaussian 代码讲解

news2025/1/20 3:38:11

这篇论文的标题是《Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction》,作者是Yiqun Lin, Hualiang Wang, Jixiang Chen和Xiaomeng Li,来自香港科技大学以及HKUST深圳-香港协同创新研究院。

这篇论文主要探讨了一种新的锥束计算机断层扫描(CBCT)重建框架,称为DIF-Gaussian,旨在通过使用更少的投影来减少辐射剂量,同时提高重建图像的质量。

给的代码只是个框架,强行复现花费时间而且以我水平容易误人子弟,就简单的对照论文理解一下,大家有兴趣可以一起讨论

项目地址:

GitHub - xmed-lab/DIF-Gaussian: MICCAI 2024: Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction

数据预处理地址
https://github.com/xmed-lab/C2RV-CBCT/tree/main/data

1、 下载代码和数据预处理方法,数据放到data中

2、发现代码是不完整的,因此边补充边写

train.py

使其与不同版本的DDP兼容

    if args.dist:
        args.local_rank = int(os.environ["LOCAL_RANK"]) # Make it compatible with different versions of DDP
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(args.local_rank)

加载cfg,项目只给出了一个default.yaml,复制一个改个名字

    cfg = load_config(args.cfg_path)
    if args.local_rank == 0:
        print(args)
        print(cfg)

        # save config
        save_dir = f'./logs/{args.name}'
        os.makedirs(save_dir, exist_ok=True)
        if os.path.exists(os.path.join(save_dir, 'config.yaml')):
            time_str = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
            shutil.copyfile(
                os.path.join(save_dir, 'config.yaml'), 
                os.path.join(save_dir, f'config_{time_str}.yaml')
            )
        shutil.copyfile(args.cfg_path, os.path.join(save_dir, 'config.yaml'))

初始化训练数据集/加载器

    train_dst = CBCT_dataset_gs(
        dst_name=args.dst_name,
        cfg=cfg.dataset,
        split='train', 
        num_views=args.num_views, 
        npoint=args.num_points,
        out_res_scale=args.out_res_scale,
        random_views=args.random_views
    )

关键在于并没有数据,因此还得自己想办法

dataset:
  root_dir: ../../datasets
  gs_res: 12 # the resolution of GS points (12^3 points in total)

进去看看数据集如何构建

class CBCT_dataset_gs(CBCT_dataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        gs_res = self.cfg.gs_res
        points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
        self.points_gs = points_gs.reshape(3, -1).transpose(1, 0) # ~[0, 1]

    def __getitem__(self, index):
        data_dict = super().__getitem__(index)

        # projections of GS points (initial center xyz)
        points_gs = deepcopy(self.points_gs)
        points_gs_proj = self.project_points(points_gs, data_dict['angles'])

        data_dict.update({
            'points_gs': points_gs,          # [K, 3]
            'points_gs_proj': points_gs_proj # [M, K, 2]
        })
        return data_dict

np.mgrid是NumPy库中的一个函数,它返回一个由给定尺寸的数组创建的多维网格。这段代码points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res创建了一个3D网格,并且将这个网格的每个点归一化到[0, 1]区间。

结果points_gs是一个4D数组,其形状为(gs_res, gs_res, gs_res, 3),其中最后一个维度包含每个网格点的x、y、z坐标。

 看getitem 

points_gs_proj = self.project_points(points_gs, data_dict['angles'])

points_gs 是一个3D网格的点,通常是用于表示3D空间中的一个体素化网格或者用于定义3D空间中的高斯分布的中心点。而 points_gs_proj 则是这些点在2D平面上的投影。

代码是不全的,后期再看看会不会更新

看LUNA16数据预处理的config 内有dataset的参数,其中的angle 为180

get返回一个3d 高斯网格,一个2d的投影

loader如下

    train_sampler = None
    if args.dist:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dst)
    train_loader = DataLoader(
        train_dst, 
        batch_size=args.batch_size, 
        sampler=train_sampler, 
        shuffle=(train_sampler is None),
        num_workers=0, # args.num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    # -- initialize evaluation dataset/loader
    eval_loader = DataLoader(
        CBCT_dataset_gs(
            dst_name=args.dst_name,
            cfg=cfg.dataset,
            split='eval',
            num_views=args.num_views,
            out_res_scale=0.5, # low-res for faster evaluation,
        ), 
        batch_size=1, 
        shuffle=False,
        pin_memory=True
    )

加载模型,模型放到后面看


    # -- initialize model
    model = DIF_Gaussian(cfg.model)
    if args.resume:
        print(f'resume model from epoch {args.resume}')
        ckpt = torch.load(
            os.path.join(f'./logs/{args.name}/ep_{args.resume}.pth'),
            map_location=torch.device('cpu')
        )
        model.load_state_dict(ckpt)
    
    model = model.cuda()
    if args.dist:
        model = nn.parallel.DistributedDataParallel(
            model, 
            find_unused_parameters=False,
            device_ids=[args.local_rank]
        )

优化器和优化器规划,损失只有一个MSE

    # -- initialize optimizer, lr scheduler, and loss function
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=args.lr, 
        momentum=0.98, 
        weight_decay=args.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=1, 
        gamma=np.power(args.lr_decay, 1 / args.epoch)
    )
    loss_func = nn.MSELoss()

开始训练

    # -- training starts
    for epoch in range(start_epoch, args.epoch + 1):
        if args.dist:
            train_loader.sampler.set_epoch(epoch)

        loss_list = []
        model.train()
        optimizer.zero_grad()

一个epoch,外部看没有花里胡哨的损失,一个损失做到底

        for k, item in enumerate(train_loader):
            item = convert_cuda(item)

            pred = model(item)
            loss = loss_func(pred['points_pred'], item['points_gt'])
            loss_list.append(loss.item())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

评估和优化

        if args.local_rank == 0:
            if epoch % 10 == 0:
                loss = np.mean(loss_list)
                print('epoch: {}, loss: {:.4}'.format(epoch, loss))
            
            if epoch % 100 == 0 or (epoch >= (args.epoch - 100) and epoch % 10 == 0):
                if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    model_state = model.module.state_dict()
                else:
                    model_state = model.state_dict()
                torch.save(
                    model_state,
                    os.path.join(save_dir, f'ep_{epoch}.pth')
                )

            if epoch % 50 == 0 or (epoch >= (args.epoch - 100) and epoch % 20 == 0):
                metrics, _ = eval_one_epoch(
                    model, 
                    eval_loader, 
                    args.eval_npoint,
                    ignore_msg=True,
                )
                msg = f' --- epoch {epoch}'
                for dst_name in metrics.keys():
                    msg += f', {dst_name}'
                    met = metrics[dst_name]
                    for key, val in met.items():
                        msg += ', {}: {:.4}'.format(key, val)
                print(msg)
        
        if lr_scheduler is not None:
            lr_scheduler.step()

model .py

看看初始化定义了什么

class DIF_Gaussian(Recon_base):
    def __init__(self, cfg):
        super().__init__(cfg)

    def init(self):
        self.init_encoder()
        
        # gaussians-related modules
        mid_ch = self.image_encoder.out_ch
        ds_ch = self.image_encoder.ds_ch
        self.gs_feats_mlp = MLP_1d([ds_ch, ds_ch // 4, mid_ch], use_bn=True, last_bn=True, last_act=False)
        self.gs_params_mlp = MLP_1d([ds_ch, ds_ch // 4, 3 + 4 + 3], use_bn=True, last_bn=False, last_act=False) # 3d: offsets, 4d: rotation, 3d: scaling
        self.gs_act = nn.LeakyReLU(inplace=True)

        self.init_decoder(mid_ch * 2)
        self.registered_point_keys = ['points', 'points_proj']

初始化编码器:self.init_encoder()

定义高斯特征和参数mlp:self.gs_feats_mlp;self.gs_params_mlp,选用线性激活self.gs_act

初始化解码器 

虽然没写完全,但是不难想象编码器和解码器的都是unet里面的

看向里面的点forward ,获取点的预测值

1多视图像素对齐功能+最大池

2gaussian-based插值函数

3逐点地预测

class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()

        self.residual = residual
        self.mlps = nn.ModuleList()

        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))

            if i != len(channels) - 1:
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))

            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)
            x_ = m(x_)
        return x_

query_view_feats:应该是对应这个公式

def query_view_feats(view_feats, points_proj, fusion='max'):
    # view_feats: [B, M, C, H, W]
    # points_proj: [B, M, N, 2]
    # output: [B, C, N, M]
    n_view = view_feats.shape[1]
    p_feats_list = []
    for i in range(n_view):
        feat = view_feats[:, i, ...] # B, C, W, H
        p = points_proj[:, i, ...] # B, N, 2
        p_feats = index_2d(feat, p) # B, C, N
        p_feats_list.append(p_feats)
    p_feats = torch.stack(p_feats_list, dim=-1) # B, C, N, M
    if fusion == 'max':
        p_feats = F.max_pool2d(p_feats, (1, p_feats.shape[-1]))
        p_feats = p_feats.squeeze(-1) # [B, C, K]
    elif fusion is not None:
        raise NotImplementedError
    return p_feats

插值如下

下面有一个点decoder

class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()

        self.residual = residual
        self.mlps = nn.ModuleList()

        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))

            if i != len(channels) - 1:
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))

            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)
            x_ = m(x_)
        return x_

用了残差网络进行预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1910284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于MySQL mvcc

innodb mvcc mvcc 多版本并发控制 在RR isolution 情况下 trx在启动的时候就拍了个快照。这个快照是基于整个数据库的。 其实这个快照并不是说拷贝整个数据库。并不是说要拷贝出这100个G的数据。 innodb里面每个trx有一个唯一的trxID 叫做trx id .在trx 开始的时候向innodb系…

录音的内容怎么做二维码?支持多种音频格式使用的制作技巧

怎么把录制的音频文件做成二维码呢?现在用二维码来存储内容是一种很常用的方式,让其他人扫描二维码来查看内容,从而提升内容传输的速度。比如现在很多人会将音频生成二维码,其他人可以通过扫码在手机上播放音频内容,那…

kafka的副本replica

指定topic的分区和副本 通过kafka命令行工具 kafka-topics.sh --create --topic myTopic --partitions 3 --replication-factor 1 --bootstrap-server localhost:9092 执行代码时指定分区个数

谈大语言模型动态思维流编排

尽管大语言模型已经呈现出了强大的威力,但是如何让它完美地完成一个大的问题,仍然是一个巨大的挑战。 需要精心地给予大模型许多的提示(Prompt)。对于一个复杂的应用场景,编写一套完整的,准确无误的提示&am…

JavaWeb__正则表达式

目录 1. 正则表达式简介2. 正则表达式体验2.1 验证2.2 匹配2.3 替换2.4 全文查找2.5 忽略大小写2.6 元字符使用2.7 字符集合的使用2.8 常用正则表达式 1. 正则表达式简介 正则表达式是描述字符模式的对象。正则表达式用于对字符串模式匹配及检索替换,是对字符串执行…

如何让 3D 数字孪生场景闪闪发光

今日图扑软件功能分享:我们将探讨 HT 系统如何通过分组管理灯光、裁切体和流光,以提高场景光影效果的精准度和整体可控性。 HT 中的灯光、裁切体、流光是会影响它所在区域一定范围内的其他节点的表现,如 场景中有个 A 灯光,默认情…

微信小程序引入自定义子组件报错,在 C:/Users/***/WeChatProjects/miniprogram-1/components/路径下***

使用原生小程序开发时候,会报下面的错误, [ pages/button/button.json 文件内容错误] pages/button/button.json: [“usingComponents”][“second-component”]: “…/…/components/second-child/index”,在 C:/Users/***/WeChatProjects/m…

布隆过滤器 redis

一.为什么要用到布隆过滤器? 缓存穿透:查询一条不存在的数据,缓存中没有,则每次请求都打到数据库中,导致数据库瞬时请求压力过大,多见于爬虫恶性攻击因为布隆过滤器是二进制的数组,如果使用了它…

小米手机短信怎么恢复?不用求人,3个技巧一网打尽

当你突然发现安卓手机里的重要短信不见了,是不是感到一阵心慌意乱?别急,不用求人,更不用焦虑。作为基本的社交功能,短信是我们与外界沟通的重要桥梁,当删除后,短信怎么恢复呢?今天&a…

Halcon 模糊圆边的找圆案例

Halcon 模糊圆边的找圆案例 基本思路 1.将图像转成灰度图像 2.再观察要找到的区域的灰度值变化,找到前景与背景的具体数值。 3.根据找到的前景与背景的具体数值,增强图像对比度。(使图像变成黑白图片) 4.使用灰度直图工具进行阈值…

ChatTTS使用

ChatTTS是一款适用于日常对话的生成式语音模型。 克隆仓库 git clone https://github.com/2noise/ChatTTS cd ChatTTS 使用 conda 安装 conda create -n chattts conda activate chattts pip install -r requirements.txt 安装完成后运行 下载模型并运行 python exampl…

android13 固定U盘链接 SD卡链接 TF卡链接 硬盘链接

1.前言 有些客户使用的应用并不带有自动监听U盘 sd卡广播的代码,使用的代码是固定的地址,这样的话,就需要我们将系统的挂载目录固定了。 原始路径 /storage/3123-19FA 增加链接 /storage/upan_000 -> /storage/3123-19FA 2. 首先如果是应用本身监听的话,使用的是 /…

美容美发在线预约小程序源码系统 前后端完整分离 带完整的安装代码包以及搭建教程

系统概述 在当今这个快节奏的社会,美容美发服务已经成为人们日常生活中不可或缺的一部分。为了满足广大消费者的便捷预约需求,以及美容美发行业的数字化转型趋势,一款高效、易用、功能全面的在线预约小程序显得尤为重要。今天,我…

纷享销客荣获CDIE“2024优秀数字化技术服务商”

近日,在第十届数字化创新博览会(CDIE 2024)上,CRM品牌领导者纷享销客凭借其卓越的技术实力和创新的解决方案,荣获“2024 优秀数字化技术服务商”奖项。 作为国内领先的CRM数字化解决方案服务商,纷享销客一直…

白盒测试的概念、特点、应用阶段、实施流程、现状与前景

文章目录 前言一、白盒测试的应用阶段二、白盒测试的特点三、白盒测试的流程四、白盒测试的现状与前景总结 前言 白盒测试(White Box Testing),又称为结构测试(Structural Testing)、透明盒测试(Glass Box…

循环练习 while

public static void main(String[] args) {double money100000;int count0;while(money>1000){if (money>50000){moneymoney-money*0.05;count;}else if (money>1000){money-1000;count;}else {break;}}System.out.println(count);} 结果为:

ggplot2绘图点的形状不够用怎么办?

群里有这么一个问题: 请问老师,fviz_pca_ind 做pca,当设置geom.ind “point”,group>6时,就不能显示第7,8组的点,应该如何处理(在不设置为文本的情况下),…

如何为IP申请SSL证书

目录 以下是如何轻松为IP地址申请SSL证书的详细步骤: 申请IP证书的基本条件: 申请IP SSL证书的方式: 确保网络通信安全的核心要素之一,是有效利用SSL证书来加密数据传输,特别是对于那些直接通过IP地址访问的资源。I…

部署Harbor镜像仓库并在k8s配置使用

文章目录 一、下载所需软件包1.docker-compose2.harbor 二、安装docker-compose1.安装docker2.配置docker-compose 三、安装harbor1.编辑harbor配置文件2.加载harbor配置(重新加载配置文件,只要修改配置文件就需要执行)3.开始安装harbor4.doc…

谷歌正在试行人脸识别办公室安全系统

内容提要: 🧿据美国消费者新闻与商业频道 CNBC 获悉,谷歌正在为其企业园区安全测试面部追踪技术。 🧿测试最初在华盛顿州柯克兰的一间办公室进行。 🧿一份内部文件称,谷歌的安全和弹性服务 (GSRS) 团队将…