Contents
The author's own implementation:
C code:
FPS (Farthest Point Sampling): the farthest point sampling algorithm
Code that triggers the error:
Fix for the error:
The author's own implementation:
extension-cpp/cuda/setup.py at dbf58c505fe5035eccdb71195b6e49c6694f7bbc · hetong007/extension-cpp · GitHub
C code:
https://github.com/Barbany/fps_cuda/tree/main
FPS (Farthest Point Sampling): the farthest point sampling algorithm
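To make the idea concrete, here is a minimal pure-PyTorch sketch of iterative farthest point sampling. It is not the CUDA kernel from the repositories above; `farthest_point_sample_torch` and its `start_idx` parameter are hypothetical names, and starting from index 0 is an assumption (the widely used PointNet++ CUDA kernel also starts from index 0). Each iteration keeps, for every point, its squared distance to the nearest already-selected point and picks the point where that distance is largest.

```python
import torch


def farthest_point_sample_torch(xyz: torch.Tensor, npoint: int, start_idx: int = 0) -> torch.Tensor:
    """Hypothetical reference FPS: (B, N, 3) points -> (B, npoint) int64 indices."""
    B, N, _ = xyz.shape
    if npoint > N:
        raise ValueError("npoint must not exceed the number of points N")
    device = xyz.device
    idx = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from each point to its nearest already-selected point.
    min_dist = torch.full((B, N), float("inf"), dtype=xyz.dtype, device=device)
    farthest = torch.full((B,), start_idx, dtype=torch.long, device=device)
    batch = torch.arange(B, device=device)
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)   # (B, 1, 3)
        dist = ((xyz - centroid) ** 2).sum(dim=-1)      # (B, N)
        min_dist = torch.minimum(min_dist, dist)
        # Next sample: the point farthest from the current selected set.
        farthest = min_dist.argmax(dim=-1)
    return idx


if __name__ == "__main__":
    pts = torch.randn(5, 1024, 3)
    print(farthest_point_sample_torch(pts, 512).shape)  # torch.Size([5, 512])
```

The loop runs npoint times with O(B·N) work per step, which is why the real implementations move it into a CUDA kernel.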
New version of the code:
pointnet2_ops_lib/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py at main · ybhhdt/pointnet2_ops_lib · GitHub
```python
import torch
from torch.autograd import Function

# _ext is the compiled extension module imported at the top of the original pointnet2_utils.py


class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        idx : torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        pts_cnt : torch.Tensor
            (B, npoint) count of points found in each ball
        """
        idx, pts_cnt = _ext.ball_query(new_xyz, xyz, radius, nsample)
        # Both outputs are integer indices/counts, so no gradient flows through them.
        ctx.mark_non_differentiable(idx)
        ctx.mark_non_differentiable(pts_cnt)
        return idx, pts_cnt
```
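For readers without the compiled extension, a pure-PyTorch sketch of what a ball query with this (idx, pts_cnt) interface computes is given here. The padding rule (unused slots repeat the first index found, and empty balls fall back to index 0) is an assumption modeled on the original PointNet++ kernels, not read from `_ext.ball_query`; `ball_query_torch` is a hypothetical name.

```python
import torch


def ball_query_torch(radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor):
    """Hypothetical reference ball query.

    xyz: (B, N, 3) points, new_xyz: (B, S, 3) centers.
    Returns idx (B, S, nsample) int32 and pts_cnt (B, S) int32.
    """
    B, N, _ = xyz.shape
    S = new_xyz.shape[1]
    # Squared distance from every center to every point: (B, S, N).
    d2 = ((new_xyz.unsqueeze(2) - xyz.unsqueeze(1)) ** 2).sum(dim=-1)
    in_ball = d2 < radius ** 2
    # Replace out-of-ball indices by the sentinel N, then sort so the
    # in-ball indices come first, in increasing index order.
    ar = torch.arange(N, device=xyz.device).expand(B, S, N)
    cand = torch.where(in_ball, ar, torch.full_like(ar, N))
    cand = cand.sort(dim=-1).values[..., :nsample]
    if cand.shape[-1] < nsample:  # fewer points than nsample: pad with the sentinel
        pad = torch.full((B, S, nsample - cand.shape[-1]), N, dtype=cand.dtype, device=cand.device)
        cand = torch.cat([cand, pad], dim=-1)
    pts_cnt = (cand < N).sum(dim=-1)
    # Assumed padding rule: unused slots repeat the first in-ball index (0 if the ball is empty).
    first = cand[..., :1]
    first = torch.where(first == N, torch.zeros_like(first), first)
    idx = torch.where(cand == N, first.expand_as(cand), cand)
    return idx.int(), pts_cnt.int()


if __name__ == "__main__":
    pts = torch.tensor([[[0.0, 0, 0], [0.5, 0, 0], [3, 0, 0], [0.2, 0, 0]]])
    centers = torch.tensor([[[0.0, 0, 0], [10, 0, 0], [3, 0, 0]]])
    print(ball_query_torch(1.0, 3, pts, centers))
```

This builds a full (B, S, N) distance tensor, so it is only practical for small inputs; the CUDA kernel avoids that memory cost.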
Code that triggers the error:
```python
import torch
from torch.autograd import Function

try:
    import builtins
except ImportError:
    import __builtin__ as builtins  # Python 2 fallback

try:
    import fps_cuda._ext as _ext
except ImportError:
    if not getattr(builtins, "__POINTNET2_SETUP__", False):
        raise ImportError(
            "Could not import _ext module.\n"
            "Please see the setup instructions in the README: "
            "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst"
        )


class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the indices of the sampled points
        """
        fps_inds = _ext.furthest_point_sampling(xyz, npoint)
        # The output is integer indices, so it is not differentiable.
        ctx.mark_non_differentiable(fps_inds)
        return fps_inds

    @staticmethod
    def backward(ctx, grad_out=None):
        # One gradient per forward input (xyz, npoint); neither receives a gradient.
        return None, None


furthest_point_sample = FurthestPointSampling.apply

xyz = torch.randn(5, 1024, 3)
npoint = 512
fps_inds = furthest_point_sample(xyz, npoint)
print(fps_inds)
```
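The FurthestPointSampling wrapper follows the usual torch.autograd.Function pattern for an op that returns integer indices: mark the output non-differentiable in forward, and return one None per forward input from backward. A CPU-only sketch of the same pattern, with a hypothetical `TopKIndices` op (built on `torch.topk`) standing in for the CUDA kernel, shows that gradients still reach the float input through the ops that consume the indices (here `torch.gather`).

```python
import torch
from torch.autograd import Function


class TopKIndices(Function):
    """Hypothetical stand-in for a CUDA index op such as furthest point sampling."""

    @staticmethod
    def forward(ctx, x, k):
        # Integer indices of the k largest entries along the last dim.
        inds = torch.topk(x, k, dim=-1).indices
        # Indices carry no gradient; tell autograd explicitly.
        ctx.mark_non_differentiable(inds)
        return inds

    @staticmethod
    def backward(ctx, grad_inds):
        # One return value per forward input (x, k); neither gets a gradient here.
        return None, None


top_k_indices = TopKIndices.apply

x = torch.randn(2, 5, requires_grad=True)
inds = top_k_indices(x, 2)
print(inds.requires_grad)  # False
# Gradients reach x through gather, which consumes the indices.
torch.gather(x, 1, inds).sum().backward()
print(x.grad)
```

Because the index output never requires grad, autograd does not call backward here; defining it still documents the one-None-per-input contract.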
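Calling furthest_point_sample (the FurthestPointSampling wrapper) fails with RuntimeError: Unknown layout; the same error is reported in: RuntimeError: Unknown layout in MultiScaleDeformableAttnFunction · Issue #284 · IDEA-Research/GroundingDINO · GitHub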
Fix: install PyTorch 2.1.0 with the matching torchvision 0.16.0 and torchaudio 2.1.0, built for CUDA 11.8:
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
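After reinstalling, it can help to confirm which PyTorch build Python actually imports, since a compiled extension such as fps_cuda._ext generally has to be rebuilt against the PyTorch version it runs with. A small check using standard attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`); the helper names and the expected 2.1.0 / 11.8 values (taken from the pip command) are illustrative.

```python
import torch


def torch_env_report() -> dict:
    """Hypothetical helper: summarize the PyTorch build that Python actually imports."""
    return {
        "torch": str(torch.__version__),        # e.g. "2.1.0+cu118"
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


def base_version(version: str) -> str:
    """Strip a local build tag such as "+cu118" from a version string."""
    return version.split("+", 1)[0]


def matches_expected(report: dict, torch_ver: str = "2.1.0", cuda_ver: str = "11.8") -> bool:
    """True if the report matches the torch / CUDA versions the pip command installs."""
    return base_version(report["torch"]) == torch_ver and report["built_with_cuda"] == cuda_ver


if __name__ == "__main__":
    report = torch_env_report()
    print(report)
    if not matches_expected(report):
        print("Installed PyTorch differs from 2.1.0 + CUDA 11.8; rebuild the extension after fixing it.")
```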