图神经网络:(图像分割)三维网格图像分割

news2025/1/17 5:54:53

文章说明:
1)参考资料:PYG的文档。文档超链。斯坦福大学的机器学习课程。课程超链。(要挂梯子)。博客原文。原文超链。(要挂梯子)。原文理论参考文献。提取码8848。
2)我在百度网盘上传这篇文章的jupyter notebook以及预训练模型。提取码8848.
3)博主水平不高,如有错误,还望批评指正
一些建议:注重理论建议直接去看文献;注重实践建议直接去看代码。他的代码会有详细注释,但实际没啥用,如果不看原文参考文献。建议手敲一遍代码,会对理解很有帮助。变量名字取得很好,如果有图神经基础,不看文献也是可以。

文章目录

  • 前言1:硬件问题
  • 前言2:有关综述
  • 数据描述
  • 数据下载
  • 任务描述
  • 代码演示

前言1:硬件问题

如果电脑不是很好,并不建议自己训练。我的电脑不是很好,训练大概有20分钟。最后电脑特别的烫,感觉对电脑很不好。我的电脑配置如下(应该是看这个,对于硬件我不清楚)。直接下载预训练的模型就好。在这里插入图片描述

前言2:有关综述

对于一般图像分割以及图像分类任务,卷积神经网络取得巨大成功。但是卷积神经网络不能处理不规则的数据结构。我们希望推广卷积神经网络到不规则数据结构。卷积神经网络博主不很了解,不所以作过多评价。图神经网络为解决问题,应孕而生。我们使用3D点云进行演示。

数据描述

我们使用两个矩阵表示数据:十分简单,看图易懂。图片自源博客。我们需要一个矩阵存储n个点的位置。我们需要一个矩阵存储点间的边关系(3点确定一个平面,这就解释为什么是3个点了)。
在这里插入图片描述

数据下载

超链。

任务描述

正如标题:一个简单分类任务。我们需要对3D点云进行分类。头部点云,躯干点云,左臂点云,左手点云,右臂点云,右手点云,左大腿点云,左小腿点云,左脚点云,右大腿点云,右小腿点云,右脚点云。

代码演示

import torch
device='cuda' if torch.cuda.is_available() else 'cpu'

路径有关注意事项1:下载数据之后不要进行解压,放在一个文件之中就可以了。
路径有关注意事项2:复制文件地址需要进行修改,可能这跟操作系统有关但是我不清楚,我就只说我的。直接复制是这样"C:\Users\19216\Desktop\project\3DImage_Classification_And_Segmentation",我们需要更改所有"\“变为”/"。

root="C:/Users/19216/Desktop/project/3DImage_Classification_And_Segmentation"

以下定义数据变换。

from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data
#BaseTransform的构造十分简单,建议自己去看源码
class NormalizeUnitSphere(BaseTransform):
	#静态方法,不依赖类(加了这个应该就不用加self了)
    @staticmethod
    def _re_center(x):
        centroid=torch.mean(x,dim=0)
        return x-centroid
    @staticmethod
    def _re_scale_to_unit_length(x):
        max_dist=torch.max(torch.norm(x,dim=1))
        return x/max_dist
    #类的默认调用方法
    def __call__(self,data:Data):
        if data.x is not None:
            data.x=self._re_scale_to_unit_length(self._re_center(data.x))
        return data
    #就是打印类的名字
    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
from torch_geometric.transforms import Compose,FaceToEdge
pre_transform=Compose([FaceToEdge(remove_faces=False),NormalizeUnitSphere()])

以下加载变换数据。

from pathlib import Path
import trimesh
def load_mesh(mesh_filename:Path):
    mesh=trimesh.load_mesh(mesh_filename,process=False)
    vertices=torch.from_numpy(mesh.vertices).to(torch.float)
    faces=torch.from_numpy(mesh.faces).t().to(torch.long).contiguous()
    return vertices,faces
from torch_geometric.data import InMemoryDataset,extract_zip
from functools import lru_cache
import numpy as np

关于这部分的代码,必须看这,看了你就知道了吧。这里代码逻辑是挺有意思的,由于篇幅原因读者自行研究。我来讲下逻辑,不一定正确哈。首先train_data申请调用SegmentationFaust。父类立马开始调用四个方法(如果没有直接跳过) raw_file_names(),processed_file_names(),download(),process()。具体到这里就只有processed_file_names()、process()。父类发现文件夹中没有processed_file_names()的对应文件,立即用process()处理数据生成processed_file_names()的对应文件。然后赋值[“training.pt”,“test.pt”]给self.processed_paths。最后子类开始运作读取数据并且赋值。所有数据在第一步处理好了。

class SegmentationFaust(InMemoryDataset):
    map_seg_label_to_id=dict(head=0,torso=1,left_arm=2,left_hand=3,
                             right_arm=4,right_hand=5,left_upper_leg=6,left_lower_leg=7,
                             left_foot=8,right_upper_leg=9,right_lower_leg=10,right_foot=11)           
    def __init__(self,root,train:bool=True,pre_transform=None):
        super().__init__(root,pre_transform)
        path=self.processed_paths[0] if train else self.processed_paths[1]
        self.data,self.slices=torch.load(path)
    #将方法转换为属性
    @property
    def processed_file_names(self)->list:
        return ["training.pt","test.pt"]
    @property
    #结果缓存,提高效率
    @lru_cache(maxsize=32)
    def _segmentation_labels(self):
        path_to_labels=Path(self.root)/"MPI-FAUST"/"segmentations.npz"
        seg_labels=np.load(str(path_to_labels))["segmentation_labels"]
        return torch.from_numpy(seg_labels).type(torch.int64)
    def _mesh_filenames(self):
        path_to_meshes=Path(self.root)/"MPI-FAUST"/"meshes"
        #正则匹配
        return path_to_meshes.glob("*.ply")
    def _unzip_dataset(self):
        path_to_zip=Path(self.root)/"MPI-FAUST.zip"
        extract_zip(str(path_to_zip),self.root,log=False)
    def process(self):
        self._unzip_dataset()
        data_list=[]
        for mesh_filename in sorted(self._mesh_filenames()):
            vertices, faces=load_mesh(mesh_filename)
            data=Data(x=vertices, face=faces)
            data.segmentation_labels=self._segmentation_labels
            if self.pre_transform is not None:
                data=self.pre_transform(data)
            data_list.append(data)
        torch.save(self.collate(data_list[:80]),self.processed_paths[0])
        torch.save(self.collate(data_list[80:]),self.processed_paths[1])
train_data=SegmentationFaust(root=root,pre_transform=pre_transform)
#输出:
#Processing...
#Done!
test_data=SegmentationFaust(root=root,train=False,pre_transform=pre_transform)
from torch_geometric.loader import DataLoader
train_loader=DataLoader(train_data,shuffle=True)
test_loader=DataLoader(test_data,shuffle=False)
from itertools import tee

这段代码特别抽象,读者自行理解研究(我的意思语法抽象不指代码逻辑)

def pairwise(iterable):
    a,b=tee(iterable)
    next(b,None)
    return zip(a,b)
import torch.nn as nn

这段代码同样抽象,读者自行理解研究(我的意思语法抽象不指代码逻辑)

def get_mlp_layers(channels:list,activation,output_activation=nn.Identity):
    layers=[]
    *intermediate_layer_definitions,final_layer_definition=pairwise(channels)
    for in_ch,out_ch in intermediate_layer_definitions:
        intermediate_layer=nn.Linear(in_ch,out_ch)
        layers+=[intermediate_layer,activation()]
    layers+=[nn.Linear(*final_layer_definition),output_activation()]
    return nn.Sequential(*layers)
from torch_geometric.nn import MessagePassing
def get_conv_layers(channels:list,conv:MessagePassing,conv_params:dict):
    conv_layers=[conv(in_ch,out_ch,**conv_params) for in_ch,out_ch in pairwise(channels)]
    return conv_layers
from torch_geometric.utils import add_self_loops,remove_self_loops
import torch.nn.functional as F

最后介绍参考论文,这里暂时放下不表
以下部分均为模型建立

class FeatureSteeredConvolution(MessagePassing):
    def __init__(self,in_channels:int,out_channels:int,num_heads:int,ensure_trans_invar:bool=True,bias:bool=True,with_self_loops:bool=True):
        super().__init__(aggr="mean")
        self.in_channels=in_channels;self.out_channels=out_channels;self.num_heads=num_heads;self.with_self_loops=with_self_loops
        self.linear=torch.nn.Linear(in_features=in_channels,out_features=out_channels*num_heads,bias=False)
        self.u=torch.nn.Linear(in_features=in_channels,out_features=num_heads,bias=False)
        self.c=torch.nn.Parameter(torch.Tensor(num_heads))
        if not ensure_trans_invar:
            self.v=torch.nn.Linear(in_features=in_channels,out_features=num_heads,bias=False)
        else:
            self.register_parameter("v",None)
        if bias:
            self.bias=torch.nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias",None)
        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.uniform_(self.linear.weight)
        torch.nn.init.uniform_(self.u.weight)
        torch.nn.init.normal_(self.c,mean=0.0,std=0.1)
        if self.v is not None:
            torch.nn.init.uniform_(self.v.weight)
        if self.bias is not None:
            torch.nn.init.normal_(self.bias,mean=0.0,std=0.1)
    def forward(self,x,edge_index):
        if self.with_self_loops:
            edge_index,_=remove_self_loops(edge_index)
            edge_index,_=add_self_loops(edge_index=edge_index,num_nodes=x.shape[0])
        out=self.propagate(edge_index,x=x)
        return out if self.bias is None else out+self.bias
    def _compute_attention_weights(self,x_i,x_j):
        if x_j.shape[-1]!=self.in_channels:
            raise ValueError(
                f"Expected input features with {self.in_channels} channels."
                f"Instead received features with {x_j.shape[-1]} channels."
            )
        if self.v is None:
            attention_logits=self.u(x_i-x_j)+self.c
        else:
            attention_logits=self.u(x_i)+self.b(x_j)+self.c
        return F.softmax(attention_logits,dim=1)
    def message(self,x_i,x_j):
        attention_weights=self._compute_attention_weights(x_i,x_j)
        x_j=self.linear(x_j).view(-1,self.num_heads,self.out_channels)
        return (attention_weights.view(-1,self.num_heads,1)*x_j).sum(dim=1)
class GraphFeatureEncoder(torch.nn.Module):
    def __init__(self,in_features,conv_channels,num_heads,apply_batch_norm:int=True,ensure_trans_invar:bool=True,bias:bool=True,with_self_loops:bool=True):
        super().__init__()
        self.apply_batch_norm=apply_batch_norm;conv_params=dict(num_heads=num_heads,ensure_trans_invar=ensure_trans_invar,bias=bias,with_self_loops=with_self_loops)
        conv_layers=get_conv_layers(channels=[in_features]+conv_channels,conv=FeatureSteeredConvolution,conv_params=conv_params)
        self.conv_layers=nn.ModuleList(conv_layers)
        *first_conv_channels,final_conv_channel=conv_channels
        self.batch_layers=[None for _ in first_conv_channels]
        if apply_batch_norm:
            self.batch_layers=nn.ModuleList([nn.BatchNorm1d(channel) for channel in first_conv_channels])
    def forward(self,x,edge_index):
        *first_conv_layers,final_conv_layer=self.conv_layers
        for conv_layer,batch_layer in zip(first_conv_layers,self.batch_layers):
            x=conv_layer(x,edge_index)
            x=F.relu(x)
            if batch_layer is not None:
                x=batch_layer(x)
        return final_conv_layer(x,edge_index)
class MeshSeg(torch.nn.Module):
    def __init__(self,in_features,encoder_features,conv_channels,encoder_channels,decoder_channels,num_heads,num_classes,apply_batch_norm=True):
        super().__init__()
        self.input_encoder=get_mlp_layers(channels=[in_features]+encoder_channels,activation=nn.ReLU)
        self.gnn=GraphFeatureEncoder(in_features=encoder_features,conv_channels=conv_channels,num_heads=num_heads,apply_batch_norm=apply_batch_norm)
        *_,final_conv_channel=conv_channels
        self.final_projection=get_mlp_layers([final_conv_channel]+decoder_channels+[num_classes],activation=nn.ReLU)
    def forward(self,data):
        x,edge_index=data.x,data.edge_index
        x=self.input_encoder(x)
        x=self.gnn(x,edge_index)
        return self.final_projection(x)

设定参数

model_params=dict(in_features=3,encoder_features=16,conv_channels=[32,64,128,64],encoder_channels=[16],decoder_channels=[32],num_heads=12,num_classes=12,apply_batch_norm=True)
net=MeshSeg(**model_params).to(device)
best_test_acc=0.0;num_epochs=50;lr=0.001;optimizer=torch.optim.Adam(net.parameters(),lr=lr);loss_fn=torch.nn.CrossEntropyLoss()

开始训练

def train(net,train_data,optimizer,loss_fn,device):
    net.train()
    cumulative_loss=0.0
    for data in train_data:
        data=data.to(device)
        optimizer.zero_grad()
        out=net(data)
        loss=loss_fn(out,data.segmentation_labels.squeeze())
        loss.backward()
        cumulative_loss+=loss.item()
        optimizer.step()
    return cumulative_loss/len(train_data)
def accuracy(predictions,gt_seg_labels):
    predicted_seg_labels=predictions.argmax(dim=-1,keepdim=True)
    if predicted_seg_labels.shape!=gt_seg_labels.shape:
        raise ValueError("Expected Shapes to be equivalent")
    correct_assignments=(predicted_seg_labels==gt_seg_labels).sum()
    num_assignemnts=predicted_seg_labels.shape[0]
    return float(correct_assignments/num_assignemnts)
def evaluate_performance(dataset,net,device):
    prediction_accuracies=[]
    for data in dataset:
        data=data.to(device)
        predictions=net(data)
        prediction_accuracies.append(accuracy(predictions,data.segmentation_labels))
    return sum(prediction_accuracies)/len(prediction_accuracies)
@torch.no_grad()
def test(net,train_data,test_data,device):
    net.eval()
    train_acc=evaluate_performance(train_data,net,device)
    test_acc=evaluate_performance(test_data,net,device)
    return train_acc,test_acc
from tqdm import tqdm
with tqdm(range(num_epochs),unit="Epoch") as tepochs:
    for epoch in tepochs:
        train_loss=train(net,train_loader,optimizer,loss_fn,device)
        train_acc,test_acc=test(net,train_loader,test_loader,device)
        tepochs.set_postfix(train_loss=train_loss,train_accuracy=100*train_acc,test_accuracy=100*test_acc)
        if test_acc>best_test_acc:
            best_test_acc=test_acc
            torch.save(net.state_dict(),root+"/checkpoint_best_colab")

开始画图

def load_model(model_params,path_to_checkpoint,device):
    try:
        model=MeshSeg(**model_params)
        model.load_state_dict(torch.load(str(path_to_checkpoint)),strict=True)
        model.to(device)
        return model
    except RuntimeError as err_msg:
        raise ValueError(
            f"Given checkpoint {str(path_to_checkpoint)} could not be loaded. {err_msg}"
        )
def get_best_model(model_params,device):
    path_to_trained_model=Path(root+"/checkpoint_best_colab")
    trained_model=load_model(model_params,path_to_trained_model,device)
    return trained_model
net=get_best_model(model_params,device)
segmentation_colors=dict(head=torch.tensor([255,255,255],dtype=torch.int),torso=torch.tensor([255,255,128],dtype=torch.int),
                         left_arm=torch.tensor([255,255,0],dtype=torch.int),left_hand=torch.tensor([255,128,255],dtype=torch.int),
                         right_arm=torch.tensor([255,128,128],dtype=torch.int),right_hand=torch.tensor([255,128,0],dtype=torch.int),
                         left_upper_leg=torch.tensor([255,0,255],dtype=torch.int),left_lower_leg =torch.tensor([255,0,128],dtype=torch.int),
                         left_foot=torch.tensor([255,0,0],dtype=torch.int),right_upper_leg=torch.tensor([128,255,255],dtype=torch.int),
                         right_lower_leg=torch.tensor([128,255,128],dtype=torch.int),right_foot=torch.tensor([128,255,0],dtype=torch.int)
)
map_seg_id_to_color=dict((_value,segmentation_colors[_key]) for _key,_value in train_data.map_seg_label_to_id.items())
@torch.no_grad()
def visualize_prediction(net,data,device,map_seg_id_to_color):
    def _map_seg_label_to_color(seg_ids,map_seg_id_to_color):
        return torch.vstack([map_seg_id_to_color[int(seg_ids[idx])] for idx in range(seg_ids.shape[0])])
    data=data.to(device)
    predictions=net(data)
    predicted_seg_labels=predictions.argmax(dim=-1,keepdim=True)
    mesh_colors=_map_seg_label_to_color(predicted_seg_labels,map_seg_id_to_color)
    segmented_mesh=trimesh.base.Trimesh(vertices=data.x.cpu().numpy(),faces=data.face.t().cpu().numpy(),process=False)
    segmented_mesh.visual.vertex_colors=mesh_colors.cpu().numpy()
    return segmented_mesh
segmented_meshes=[]
mesh_ids=[0,1,2,3,4,5,6,7,8,9]
for idx,mesh_id in enumerate(mesh_ids):
    segmented_mesh=visualize_prediction(net,test_data[mesh_id],device,map_seg_id_to_color)
    segmented_mesh.vertices+=[idx*1.0,0.0,0.0]
    segmented_meshes.append(segmented_mesh)
scene=trimesh.scene.Scene(segmented_meshes)
scene.show()

在这里插入图片描述
论文部分不想写了。以后再来吧,那就这样吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/731838.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

qt信号与槽

信号与槽的概念: 1>信号:信号就是信号函数,可以是组件自身提供,也可以是用户自己定义,自定义时,需要类体的signals权限下进行定义,该函数是一个不完整的函数,只有声明&#xff0…

输入一个链表,输出该链表的倒数第 k 的结点

一、思路 假设 K 是 2,根据下面的图片可以看出,倒数第 K 个结点就是 45。 需要注意的前提是,K 不能是负数也不能是 0 并且也不能超过链表的结点个数,因为要保证 K 是在链表的范围里,才能找到 K,然后返回这…

【网络】TCP三次握手和四次挥手(感性理解)

目录 三次握手 文字描述三次握手过程 为什么是三次握手? 什么是SYN洪水? 连接和半连接队列 一次、两次握手行不行,四/五/六次握手行不行? 三次握手一定会成功吗? 三次握手的过程中可不可以携带数据 TCP中的IS…

模块化规范

常用模块化有两种规范,commonJS和ES6 一:两者区别 二:如何转义? 我们常遇到的使用场景是,在commonJS的模块里需要引入ES6规范的模块。这时就需要把ES6模块转译为commonJS规范的模块,否则报错 转义工具有…

javassist 02 implement interface

创建 interface package com.wsd;public interface AccountDao {int delete(); }利用 javassist 生产一个 类A, Class A implements AccountDao package com.wsd;import javassist.ClassPool; import javassist.CtClass; import javassist.CtMethod; import javassist.Modifi…

mac桌面时钟 浮动 (python)

浮动时钟,多地时区 app store的都要钱,于是。。。。我们让chatgpt来实现一个吧: 数字: 代码: import sys import datetime import pytzfrom PyQt5.QtWidgets import QApplication, QMainWindow, QGraphicsView, QGr…

深度学习不同数据增广方法的选用分析

一般情况下,可以将数据扩增方法分为单数据变形、多数据混合、学习数据分布规律生成新数据和学习增广策略等4 类方法。以上顺序也在一定程度上反映了数据增广方法的发展历程。如果与Shorten和Khoshgoftaar的成果对照,就图像数据而言,基于数据变…

抖音矩阵源码搭建开发技术部署分析

目录 一、 什么是抖音矩阵?源码搭建开发注意事项? 1. 抖音矩阵概述 2. 源码搭建开发注意事项: 二、 使用步骤及开发代码展示 一、 什么是抖音矩阵?源码搭建开发注意事项? 1. 抖音矩阵概述 首先,抖音账…

21夜间车牌识别(matlab程序)

1.简述 简单说一下实现思路: 读取图片,转灰度,计算灰度直方图,估算阈值(这里的阈值计算很重要,经过阈值算法,选取一个最恰当的阈值),之后二值化。显示图像即可。 实现目…

爬虫爬取公众号文章

前言 自从chatGPT出现后,对于文本处理的能力直接上升了一个维度。在这之前,我们爬取到网络上的文本内容之后,都需要写一个文本清理的程序,对文本进行清洗,而现在,有了chatGPT的加持,我们只需要…

解决程序占用较多内存的问题

今天发现自己开发的一个程序占用了大量内存而且不会自动释放 ,我的程序在windows中运行的,解决办法如下: 第一步:打开任务管理器,打到正在运行程序 (这里以sql server为例),然后右击…

设计合并排序算法实现对N个整数排序。

1.题目 设计合并排序算法实现对N个整数排序 2.设计思路 先将无序序列利用分治法划分为子序列,直至每个子序列只有一个元素,然后再对有序子序列逐步进行合并排序。合并方法是循环的将两个有序子序列当前的首元素进行比较,较小的元素取出,置入合并序列的左边空置位,直至其中…

特征选择算法 | Matlab 基于最大相关最小冗余特征选择算法(mRMR)的分类数据特征选择

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 特征选择算法 | Matlab 基于最大相关最小冗余特征选择算法(mRMR)的分类数据特征选择 部分源码 %--------------------

Redis实战案例12-添加秒杀券实现秒杀下单及相关问题解决

1. 添加优惠券 该项目没有后台管理的界面,所以采用postman发送请求 http://localhost:8081/voucher/seckill注意end时间要大于当前系统时间 {"shopId": 2,"title": "100元代金券","subTitle": "周一至周五均可使用&qu…

c++查漏补缺

c语言的struct只能包含变量,而c中的class除了包含变量,还可以包含函数。 通过结构体定义出来的变量还是变量,而通过类定义出来有了新的名称,叫做对象。C语言中,会将重复使用或具有某项功能的代码封装成一个函数&#x…

【剑指offer】8. 斐波那契数列(java)

文章目录 斐波那契数列描述输入描述:返回值描述:示例1示例2示例3思路非递归递归 完整代码 斐波那契数列 描述 大家都知道斐波那契数列,现在要求输入一个正整数 n ,请你输出斐波那契数列的第 n 项。 斐波那契数列是一个满足 f …

PHP学生工作平台管理系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP学生工作平台管理系统 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为PHP APACHE,数据 库为mysql5.0,使用php语言开发…

linux 如何挂载fat32格式u盘,如何挂载NTFS 文件系统的硬盘

linux系统默认可以识别fat32u盘,对ntfs格式u盘不能识别 具体挂载方式如下 1、插入u盘 2、mkdir /mnt/usb 此命令用于创建挂载u盘的目录,只需创建一次就可以,若已经存在则不需要再次创建 3、fdisk -l 找到u盘路径 上图显示的sdb1,sdb2,sdb5…

Gradio,我们可以为我们的模型创建Web界面

Gradio是一个Python库,允许我们快速为机器学习模型创建可定制的接口。 使用Gradio,我们可以为我们的模型创建Web界面,而无需编写任何HTML,CSS或JavaScript。 Gradio旨在与广泛的机器学习框架配合使用,包括TensorFlow&a…

IOU发展历程学习记录

概述 IOU的出现主要最先运用在预测bbox框和target bbox框之间的重叠问题,为NMS提供相应的数值支撑。另外在bbox框的回归问题上,由于L1 Loss存在如下问题:当损失函数对x的导数为常数,在训练后期,x很小时,若…