PointNet代码详解

news2025/1/19 20:19:09

PointNet代码详解

PointNet
这里只看下分类网络

input transform

class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x

class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

输入点云(32,3,2500)32个批次,3代表坐标,2500代表2500个点
经过STN3d输出的维度是 ([32, 3, 3])

提取点云特征

class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

当global_feat为True时,提取全局特征;global为False时,提取点和全局拼接的特征。

PointNetCls

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat

class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat
# 计算输入矩阵trans的特征变换正则化损失。
def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

测试

if __name__ == '__main__':
    sim_data = Variable(torch.rand(32,3,2500))
    trans = STN3d()
    out = trans(sim_data)
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out))

    sim_data_64d = Variable(torch.rand(32, 64, 2500))
    trans = STNkd(k=64)
    out = trans(sim_data_64d)
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out))

    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data)
    print('global feat', out.size())

    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data)
    print('point feat', out.size())

    cls = PointNetCls(k = 5)
    out, _, _ = cls(sim_data)
    print('class', out.size())

    seg = PointNetDenseCls(k = 3)
    out, _, _ = seg(sim_data)
    print('seg', out.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1280375.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人工智能轨道交通行业周刊-第67期(2023.11.27-12.3)

本期关键词:列车巡检机器人、城轨智慧管控、制动梁、断路器、AICC大会、Qwen-72B 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro…

C++ day51 买卖股票最佳时期

题目1:309 买卖股票的最佳时机含冷冻期 题目链接:买卖股票的最佳时机含冷冻期 对题目的理解 prices[i]表示第i天股票的价格,尽可能多地完成更多的交易,不能同时进行多笔交易,卖出股票后,第二天无法买入股…

java后端自学错误总结

java后端自学错误总结 MessageSource国际化接口总结 MessageSource国际化接口 今天第一次使用MessageSource接口,比较意外遇到了一些坑 messageSource是spring中的转换消息接口,提供了国际化信息的能力。MessageSource用于解析 消息,并支持消息的参数化…

Spring Initial 脚手架国内镜像地址

官方的脚手架下载太慢了,并且现在没有了Java8的选项,所以找到国内的脚手架镜像地址,推荐给大家。 首先说官方的脚手架 官方的脚手架地址为: https://start.spring.io/ 但是可以看到,并没有了Java8的选项。 所以推荐…

如何通过K线发现短线机会?

一、K线的含义 股票一天之内有4个最关键的价格,开盘价、收盘价、最高价和最低价,把这个价格显示在图上就是K线图。 以金斗云智投电脑版为例,打开软件,任意搜索一支个股,就可以看到这支股票的K线。 股市新手看到这儿多…

【像素画板】游戏地图编辑器-uniapp项目开发流程详解

嘿,用过像素画板没有哦,相信喜欢绘画的小朋友会对它感兴趣呢,用来绘制像素画非常好看,有没有发现,它是可以用来绘制游戏地图的,是不是很好奇,来一起看看吧。 像素画板,也叫像素画的绘…

DeDeCMS v5.7 SP2 正式版 前台任意用户密码修改(漏洞复现)

1.环境搭建 PHP 5.6 DeDeCMSV5.7SP2 正式版 安装phpstudy,https://www.xp.cn/小皮面板 先启动Apache2.4.39和MySQL5.7.26 如果他会让你下载,点击是就好! 让后点击网站—>点击创建网站 域名自己创建,自己取 其他的不变 点击…

iOS 自动签名打包,并用脚本上传appstore

背景: 1)测试环境给测试,产品,或者其他业务人员打测试包时,经常存在需要添加设备,不得不重新生成描述文件,手动去更新打包机描述文件配置 2)证书,描述文件过期造成打包失…

Session 与 JWT 的对决:谁是身份验证的王者? (上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

树与二叉树堆:经典OJ题集(2)

目录 二叉树的性质及其问题: 二叉树的性质 问题: 一、对称的二叉树: 题目: 解题思路: 二、另一棵树: 题目: 解题思路: 三、翻转二叉树: 题目:…

windows如何配置java环境变量(java环境变量配置教程)

本文章以Windows为例记录Java环境变量配置详情。 笔者系统为windows10, JDK百度云链接放在文末有需要的可以下载。 当我们下载并安装好JDK后需要配置环境变量,否则无法方便的使用JDK中的工具进行Java源文件的编译及运行,或者其他工具的使用…

numpy知识库:基于numpy绘制灰度直方图

前言 对于灰度图像而言,灰度直方图可以统计灰度图像内各个灰度级出现的次数。 灰度直方图的横坐标是灰度图像中各像素点的灰度级。灰度的数值范围为[0, 255]。因此,如果将图像分为256个灰度级,那么每个灰度级唯一对应一个灰度;如…

分享一个大学生免费的资源网站(含考研资源,竞赛四六级)

今天不小心从其他地方链接到的网站,里面包含考考研资料,四六级相关的资料,重点都是免费的,部分资料可能需要登录或者关注公众号才可见,,网站链接了CSDN 能跳转到CSND, 网站地址 :忠哥资源共享http://jian…

奇技淫巧第9期

今天回顾一下 5~12 月所遇到的零碎知识点。 文章目录 歪门邪道优雅删除“学习资料”快速下载 vscode两种硬盘格式zotero在word中插入参考文献markdown 下划线查看 CPU Linux 命令postgres 无法通过 root 用户操作bash 初学者礼包gitwin 11 edge 浏览器0x80190001 报错 python …

inux基础项目开发1:量产工具——业务系统(七)

前言: 前面我们已经构造出来显示系统、输入系统、文字系统、UI系统、页面系统,这个项目百分之八十需要实现的都已经构建出来了,最后让我们对这个项目进行最后一项系统的搭建,也就是业务系统,说到业务大家应该就知道我们…

软件生命周期四个阶段SDLC

软件产品生命周期:指软件产品研发全部过程、活动和任务的结构框架。 产品的生命周期一般包括四个阶段:引入期、成长期、成熟期和衰退期,在不同的阶段中,市场对产品的反应不同,其销售特点不同,因而产品管理的…

Windows驱动中使用数字签名验证控制设备访问权限

1. 背景 在一般的驱动开发时,创建了符号链接后在应用层就可以访问打开我们的设备并进行通讯。 但我们有时候不希望非自己的进程访问我们的设备并进行交互,虽然可以使用 IoCreateDeviceSecure 来创建有安全描述符的设备,但大数的用户账户为了方…

【理解ARM架构】中断处理 | CPU模式

🐱作者:一只大喵咪1201 🐱专栏:《理解ARM架构》 🔥格言:你只管努力,剩下的交给时间! 目录 🍜中断🍨GPIO中断代码实现 🍜CPU🍨CONTROL…

设计模式-结构型模式之装饰者设计模式

文章目录 六、装饰者模式 六、装饰者模式 装饰者模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其结构。它是作为现有的类的一个包装。 装饰类和被装饰类可以独立发展,不会相互耦合,装饰者模…

各大期刊网址

AAAL: http://dblp.uni-trier.de/db/conf/aaai/ CVPR: http://dblp.uni-trier.de/db/conf/cvpr/ NeurlPS:http://dblp.uni-trier.de/db/conf/nips/ ICCV: http://dblp.uni-trier.de/db/conf/iccv/ IJCAL: http://dblp.uni-trier.de/db/conf/ijcal/ 并非原创引…