This paper, *Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction*, is by Yiqun Lin, Hualiang Wang, Jixiang Chen, and Xiaomeng Li, from The Hong Kong University of Science and Technology and the HKUST Shenzhen-Hong Kong Collaborative Innovation Research Institute.
It proposes a new cone-beam computed tomography (CBCT) reconstruction framework, DIF-Gaussian, which aims to reduce radiation dose by reconstructing from far fewer projections while improving image quality.
The released code is only a skeleton. Forcing a full reproduction would take a lot of time and, at my level, might mislead readers, so I will simply read the code against the paper; anyone interested is welcome to discuss.
Project repository:
GitHub - xmed-lab/DIF-Gaussian: MICCAI 2024: Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction
Data preprocessing:
https://github.com/xmed-lab/C2RV-CBCT/tree/main/data
1. Download the code and the data preprocessing scripts, and put the data under data/.
2. It turns out the code is incomplete, so I fill in the blanks as I read.
train.py
Make it compatible with different versions of DDP (reading LOCAL_RANK from the environment is what newer launchers such as torchrun provide):
if args.dist:
    args.local_rank = int(os.environ["LOCAL_RANK"])  # make it compatible with different versions of DDP
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(args.local_rank)
Load the config. The project ships only a single default.yaml; copy it and rename it:
cfg = load_config(args.cfg_path)

if args.local_rank == 0:
    print(args)
    print(cfg)

    # save config
    save_dir = f'./logs/{args.name}'
    os.makedirs(save_dir, exist_ok=True)
    if os.path.exists(os.path.join(save_dir, 'config.yaml')):
        time_str = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
        shutil.copyfile(
            os.path.join(save_dir, 'config.yaml'),
            os.path.join(save_dir, f'config_{time_str}.yaml')
        )
    shutil.copyfile(args.cfg_path, os.path.join(save_dir, 'config.yaml'))
Initialize the training dataset/loader:
train_dst = CBCT_dataset_gs(
    dst_name=args.dst_name,
    cfg=cfg.dataset,
    split='train',
    num_views=args.num_views,
    npoint=args.num_points,
    out_res_scale=args.out_res_scale,
    random_views=args.random_views
)
The catch is that the data itself is not provided, so you have to prepare it on your own. The relevant part of the config:
dataset:
  root_dir: ../../datasets
  gs_res: 12  # the resolution of GS points (12^3 points in total)
Let's look at how the dataset is built:
class CBCT_dataset_gs(CBCT_dataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        gs_res = self.cfg.gs_res
        points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
        self.points_gs = points_gs.reshape(3, -1).transpose(1, 0)  # ~[0, 1]

    def __getitem__(self, index):
        data_dict = super().__getitem__(index)
        # projections of GS points (initial center xyz)
        points_gs = deepcopy(self.points_gs)
        points_gs_proj = self.project_points(points_gs, data_dict['angles'])
        data_dict.update({
            'points_gs': points_gs,           # [K, 3]
            'points_gs_proj': points_gs_proj  # [M, K, 2]
        })
        return data_dict
np.mgrid is a NumPy function that builds a dense multi-dimensional coordinate grid. The line points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res creates a 3D grid whose coordinates are normalized to roughly [0, 1) (precisely, [0, (gs_res - 1) / gs_res]). The result is a 4D array of shape (3, gs_res, gs_res, gs_res), where the first dimension holds the x, y, z coordinates of each grid point; reshape(3, -1).transpose(1, 0) then flattens it into a [K, 3] list of K = gs_res^3 points.
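A quick standalone check of those shapes (the printed values are just what NumPy produces for gs_res = 12):

import numpy as np

gs_res = 12
grid = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
print(grid.shape)                  # (3, 12, 12, 12)

points = grid.reshape(3, -1).transpose(1, 0)
print(points.shape)                # (1728, 3), i.e. 12^3 points
print(points.min(), points.max())  # 0.0 and 11/12 ≈ 0.917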
Now look at __getitem__:
points_gs_proj = self.project_points(points_gs, data_dict['angles'])
points_gs holds the points of the 3D grid, used here as the initial center positions of the 3D Gaussians, while points_gs_proj contains the projections of those points onto each of the M 2D views.
The code is incomplete; I will check later whether it gets updated.
The LUNA16 preprocessing config contains the dataset parameters; the total scan angle there is 180 degrees.
So __getitem__ returns the 3D Gaussian grid plus its 2D projections.
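project_points comes from the base dataset class and uses the scanner geometry from the preprocessing config; it is not shown in the release. Purely as a toy illustration of what the [M, K, 2] output means, a parallel-beam simplification could look like this (this is my sketch, not the repo's implementation):

import numpy as np

def project_points_parallel(points, angles):
    # points: [K, 3] in [0, 1); angles: [M] gantry angles in degrees.
    # Toy parallel-beam projection only; the real cone-beam geometry also
    # involves source and detector distances from the config.
    pts = points - 0.5                                         # center the volume
    projs = []
    for a in np.deg2rad(angles):
        u = pts[:, 0] * np.cos(a) + pts[:, 1] * np.sin(a)      # detector u-axis
        v = pts[:, 2]                                          # detector v-axis
        projs.append(np.stack([u, v], axis=-1))                # [K, 2]
    return np.stack(projs, axis=0)                             # [M, K, 2]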
The loader is as follows (note that num_workers is hard-coded to 0, in which case worker_init_fn is never actually called):
train_sampler = None
if args.dist:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dst)

train_loader = DataLoader(
    train_dst,
    batch_size=args.batch_size,
    sampler=train_sampler,
    shuffle=(train_sampler is None),
    num_workers=0,  # args.num_workers,
    pin_memory=True,
    worker_init_fn=worker_init_fn
)
# -- initialize evaluation dataset/loader
eval_loader = DataLoader(
    CBCT_dataset_gs(
        dst_name=args.dst_name,
        cfg=cfg.dataset,
        split='eval',
        num_views=args.num_views,
        out_res_scale=0.5,  # low-res for faster evaluation
    ),
    batch_size=1,
    shuffle=False,
    pin_memory=True
)
Load the model; we will look at the model itself later.
# -- initialize model
model = DIF_Gaussian(cfg.model)

if args.resume:
    print(f'resume model from epoch {args.resume}')
    ckpt = torch.load(
        os.path.join(f'./logs/{args.name}/ep_{args.resume}.pth'),
        map_location=torch.device('cpu')
    )
    model.load_state_dict(ckpt)

model = model.cuda()
if args.dist:
    model = nn.parallel.DistributedDataParallel(
        model,
        find_unused_parameters=False,
        device_ids=[args.local_rank]
    )
Optimizer and LR scheduler; the only loss is an MSE:
# -- initialize optimizer, lr scheduler, and loss function
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.98,
    weight_decay=args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=np.power(args.lr_decay, 1 / args.epoch)
)
loss_func = nn.MSELoss()
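A quick check of the scheduler setup: stepping once per epoch with gamma = lr_decay^(1/epoch) multiplies the learning rate by exactly args.lr_decay in total over the run (the values below are hypothetical):

import numpy as np

lr_decay, epochs = 0.05, 400            # hypothetical args.lr_decay / args.epoch
gamma = np.power(lr_decay, 1 / epochs)
print(gamma ** epochs)                  # ≈ 0.05, the total decay factor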
Training starts:
# -- training starts
for epoch in range(start_epoch, args.epoch + 1):
    if args.dist:
        train_loader.sampler.set_epoch(epoch)

    loss_list = []
    model.train()
    optimizer.zero_grad()
Inside one epoch, from the outside there is nothing fancy about the loss: a single MSE does all the work:
    for k, item in enumerate(train_loader):
        item = convert_cuda(item)
        pred = model(item)
        loss = loss_func(pred['points_pred'], item['points_gt'])
        loss_list.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
Logging, checkpointing, evaluation, and the LR step:
    if args.local_rank == 0:
        if epoch % 10 == 0:
            loss = np.mean(loss_list)
            print('epoch: {}, loss: {:.4}'.format(epoch, loss))

        if epoch % 100 == 0 or (epoch >= (args.epoch - 100) and epoch % 10 == 0):
            if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
                model_state = model.module.state_dict()
            else:
                model_state = model.state_dict()
            torch.save(
                model_state,
                os.path.join(save_dir, f'ep_{epoch}.pth')
            )

        if epoch % 50 == 0 or (epoch >= (args.epoch - 100) and epoch % 20 == 0):
            metrics, _ = eval_one_epoch(
                model,
                eval_loader,
                args.eval_npoint,
                ignore_msg=True,
            )
            msg = f' --- epoch {epoch}'
            for dst_name in metrics.keys():
                msg += f', {dst_name}'
                met = metrics[dst_name]
                for key, val in met.items():
                    msg += ', {}: {:.4}'.format(key, val)
            print(msg)

    if lr_scheduler is not None:
        lr_scheduler.step()
model.py
Let's see what the initialization defines:
class DIF_Gaussian(Recon_base):
    def __init__(self, cfg):
        super().__init__(cfg)

    def init(self):
        self.init_encoder()

        # gaussians-related modules
        mid_ch = self.image_encoder.out_ch
        ds_ch = self.image_encoder.ds_ch
        self.gs_feats_mlp = MLP_1d([ds_ch, ds_ch // 4, mid_ch], use_bn=True, last_bn=True, last_act=False)
        self.gs_params_mlp = MLP_1d([ds_ch, ds_ch // 4, 3 + 4 + 3], use_bn=True, last_bn=False, last_act=False)  # 3d: offsets, 4d: rotation, 3d: scaling
        self.gs_act = nn.LeakyReLU(inplace=True)

        self.init_decoder(mid_ch * 2)
        self.registered_point_keys = ['points', 'points_proj']
Initialize the encoder: self.init_encoder().
Define the Gaussian feature and parameter MLPs, self.gs_feats_mlp and self.gs_params_mlp; the activation self.gs_act is a LeakyReLU.
Initialize the decoder, whose input width is mid_ch * 2 (view features concatenated with Gaussian features).
Even though the code is not complete, it is not hard to imagine that the encoder and decoder are U-Net components.
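The release does not show how the 3 + 4 + 3 outputs of gs_params_mlp are consumed, but in standard 3D Gaussian Splatting practice the quaternion and scaling define a covariance Σ = R S Sᵀ Rᵀ. A minimal sketch of that construction (build_covariance is my name, not the repo's):

import torch
import torch.nn.functional as F

def build_covariance(quat, scale):
    # quat:  [K, 4] rotation quaternion (w, x, y, z), unnormalized
    # scale: [K, 3] per-axis scales (assumed positive, e.g. after exp/softplus)
    q = F.normalize(quat, dim=-1)
    w, x, y, z = q.unbind(-1)
    # standard quaternion-to-rotation-matrix formula
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)
    M = R @ torch.diag_embed(scale)
    return M @ M.transpose(1, 2)  # Σ = R S Sᵀ Rᵀ, shape [K, 3, 3]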
Now the point-wise forward pass, which produces the per-point predictions in three steps (a hedged sketch assembling them is given at the end of this section):
1. multi-view pixel-aligned features + max pooling;
2. a Gaussian-based interpolation function;
3. point-wise prediction.
query_view_feats should correspond to the paper's feature-querying formula, roughly f(p) = max over views m = 1..M of Φ(F_m, π_m(p)), where π_m projects a 3D point p onto view m and Φ is bilinear sampling:
def query_view_feats(view_feats, points_proj, fusion='max'):
    # view_feats: [B, M, C, H, W]
    # points_proj: [B, M, N, 2]
    # output: [B, C, N, M] (or [B, C, N] after max fusion)
    n_view = view_feats.shape[1]
    p_feats_list = []
    for i in range(n_view):
        feat = view_feats[:, i, ...]  # B, C, H, W
        p = points_proj[:, i, ...]    # B, N, 2
        p_feats = index_2d(feat, p)   # B, C, N
        p_feats_list.append(p_feats)
    p_feats = torch.stack(p_feats_list, dim=-1)  # B, C, N, M

    if fusion == 'max':
        p_feats = F.max_pool2d(p_feats, (1, p_feats.shape[-1]))
        p_feats = p_feats.squeeze(-1)  # [B, C, N]
    elif fusion is not None:
        raise NotImplementedError
    return p_feats
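index_2d is not included in the snippet either. A plausible PIFu-style implementation based on F.grid_sample, assuming points_proj is already normalized to [-1, 1], would be (this is my sketch, not the repo's code):

import torch.nn.functional as F

def index_2d(feat, uv):
    # feat: [B, C, H, W] one view's feature map
    # uv:   [B, N, 2] projected coordinates, assumed normalized to [-1, 1]
    uv = uv.unsqueeze(2)                                   # [B, N, 1, 2]
    samples = F.grid_sample(feat, uv, align_corners=True)  # [B, C, N, 1], bilinear by default
    return samples.squeeze(-1)                             # [B, C, N]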
The Gaussian-based interpolation itself does not appear in the released code.
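As a stand-in, here is a minimal sketch of the idea assuming isotropic Gaussians (one scalar scale per Gaussian); the paper's formulation uses the full anisotropic covariance built from the predicted rotation and scaling, and all names here are mine:

import torch

def gaussian_interp(points, centers, feats, scales, eps=1e-8):
    # points:  [B, N, 3] query points
    # centers: [B, K, 3] Gaussian centers (initial grid + predicted offsets)
    # feats:   [B, K, C] per-Gaussian features
    # scales:  [B, K]    isotropic std per Gaussian (anisotropic in the paper)
    d2 = ((points[:, :, None, :] - centers[:, None, :, :]) ** 2).sum(-1)  # [B, N, K] squared distances
    w = torch.exp(-0.5 * d2 / (scales[:, None, :] ** 2 + eps))            # Gaussian weights
    w = w / (w.sum(-1, keepdim=True) + eps)                               # normalize over the K Gaussians
    return torch.einsum('bnk,bkc->bnc', w, feats)                         # [B, N, C] interpolated features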
Next comes the point decoder:
class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()
        self.residual = residual
        self.mlps = nn.ModuleList()
        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                # residual layers take the original input (channels[0]) concatenated back in
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))
            if i != len(channels) - 1:  # note: i only reaches len(channels) - 2, so this is always true
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))
            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)  # re-concatenate the original input as a skip
            x_ = m(x_)
        return x_
So prediction uses a residual-style MLP: every layer after the first re-concatenates the original input features before its 1x1 convolution.
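Finally, since forward() is also missing from the release, here is my best guess at how the three steps fit together. This is a heavily hedged sketch: the encoder API, the dict keys, self.point_decoder as the module created by init_decoder, and the use of the gaussian_interp/index_2d sketches above are all assumptions, not the authors' code:

import torch

def forward_sketch(self, item):
    # 2D encoder: assumed to return full-res and downsampled multi-view features
    proj_feats, ds_feats = self.image_encoder(item['projs'])

    # 1. pixel-aligned features for the Gaussian centers, max-fused across views
    gs_view_feats = query_view_feats(ds_feats, item['points_gs_proj'])  # [B, ds_ch, K]
    gs_feats = self.gs_feats_mlp(gs_view_feats)                         # [B, mid_ch, K]
    gs_params = self.gs_params_mlp(gs_view_feats)                       # [B, 10, K]: 3 offsets + 4 quat + 3 scales
    centers = item['points_gs'] + gs_params[:, :3].transpose(1, 2)      # move grid centers by predicted offsets

    # 2. Gaussian-based interpolation at the query points (isotropic sketch;
    #    the scales are crudely reduced to one scalar per Gaussian here)
    f_gs = gaussian_interp(item['points'], centers,
                           gs_feats.transpose(1, 2),
                           gs_params[:, 7:].mean(1)).transpose(1, 2)    # [B, mid_ch, N]

    # 3. point-wise prediction from concatenated features (mid_ch * 2 input)
    f_view = query_view_feats(proj_feats, item['points_proj'])          # [B, mid_ch, N]
    return {'points_pred': self.point_decoder(torch.cat([f_view, f_gs], dim=1))}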