aidsorb.models
This module provides deep learning architectures for point cloud processing.
Note
Currently, only PointNet is implemented. It is a lightweight version of
the original architecture in which the TNets for the
input and feature transforms have been removed.
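With the TNets removed, what remains of PointNet can be sketched in a few lines of PyTorch. A shared per-point MLP (1x1 `Conv1d` layers) is followed by a max over the point dimension, and that max is what makes the output independent of point order. This is a minimal illustration under stated assumptions, not aidsorb's implementation. The class name `TinyPointNetBackbone` is hypothetical, and the layer widths (64, 128, `n_global_feats`) were picked for brevity.

```python
import torch
from torch import nn


class TinyPointNetBackbone(nn.Module):
    """Hypothetical TNet-free PointNet backbone (illustration only)."""

    def __init__(self, in_channels: int = 4, n_global_feats: int = 256):
        super().__init__()
        # kernel_size=1 applies the same MLP to every point independently.
        # Batch normalization is omitted for brevity.
        self.shared_mlp = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, n_global_feats, kernel_size=1),
        )

    def forward(self, x: torch.Tensor):
        # (B, C, N) -> per-point features (B, n_global_feats, N).
        point_feats = self.shared_mlp(x)
        # Max over the points (a symmetric function) yields one global vector per
        # cloud; the argmax positions are the indices of the "critical" points.
        global_feats, critical_idx = point_feats.max(dim=2)
        return global_feats, critical_idx


torch.manual_seed(0)
backbone = TinyPointNetBackbone(in_channels=4, n_global_feats=256)
x = torch.randn(32, 4, 300)
with torch.no_grad():
    feats, idx = backbone(x)
    # Shuffling the order of the points leaves the global features unchanged.
    perm = torch.randperm(300)
    feats_perm, _ = backbone(x[:, :, perm])

print(feats.shape, idx.shape)  # torch.Size([32, 256]) torch.Size([32, 256])
print(torch.allclose(feats, feats_perm, atol=1e-5))  # True
```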
Warning
It is recommended to use batched inputs in all cases. For example, even
if a single point cloud (pcd) of shape (3+C, N) is to be processed with
PointNet, reshape it to (1, 3+C, N). One way to do this
is pcd = pcd.unsqueeze(0).
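You can check the reshaping from the warning using plain PyTorch tensors. In this sketch, a random cloud with 3 coordinates, one extra channel (C=1) and N=100 points stands in for real data. It also shows `pcd[None]`, which does the same thing as `unsqueeze(0)`. The model output is a random stand-in, not a real prediction.

```python
import torch

# A single point cloud: 3 coordinates + C=1 extra channel, N=100 points.
pcd = torch.randn(3 + 1, 100)

# Add a leading batch dimension of size 1: (3+C, N) -> (1, 3+C, N).
batched = pcd.unsqueeze(0)
print(batched.shape)  # torch.Size([1, 4, 100])

# Indexing with None is an equivalent way to add the batch dimension.
assert torch.equal(batched, pcd[None])

# unsqueeze returns a view, so no data is copied.
assert batched.data_ptr() == pcd.data_ptr()

# A batched output of shape (1, n_outputs) can be un-batched with squeeze(0).
out = torch.randn(1, 2)  # stand-in for a model output
print(out.squeeze(0).shape)  # torch.Size([2])
```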
Todo
Add more architectures for point cloud processing.
References
[PointNet] C. R. Qi, H. Su, K. Mo and L. J. Guibas, “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation,” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, USA, 2017, pp. 77-85, doi: 10.1109/CVPR.2017.16.
class aidsorb.models.PointNet(head, local_feats=False, n_global_feats=1024)
Bases: torch.nn.Module

Vanilla version from the [PointNet] paper, with the TNets removed.

PointNet takes a point cloud as input and produces one or more outputs. The type of task is determined by head. Currently implemented heads include:

- PointNetClsHead: classification and regression
- PointNetSegHead: segmentation
The input must be batched, i.e. have shape (B, C, N), where B is the batch size, C is the number of input channels and N is the number of points in each point cloud.

Tip
You can define a custom head, custom_head, as a torch.nn.Module and pass it to head.
If local_feats=False, the input to custom_head must have the same shape as in PointNetClsHead.forward(). Otherwise, the input to custom_head must have the same shape as in PointNetSegHead.forward().

Parameters:
- head (torch.nn.Module)
- local_feats (bool, default=False)
- n_global_feats (int, default=1024)
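The tip on custom heads can be illustrated with a small regression head. The sketch assumes that, with local_feats=False, the head receives a global feature tensor of shape (B, n_global_feats) and returns (B, n_outputs). The classification example suggests this, but check PointNetClsHead.forward() for the exact expected input. The class name `MyRegressionHead` and its layer sizes are hypothetical.

```python
import torch
from torch import nn


class MyRegressionHead(nn.Module):
    """Hypothetical custom head: global features -> n_outputs values."""

    def __init__(self, n_global_feats: int = 256, n_outputs: int = 1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_global_feats, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, n_outputs),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, n_global_feats), assumed to be the max-pooled global features.
        return self.mlp(x)


head = MyRegressionHead(n_global_feats=256, n_outputs=1)
global_feats = torch.randn(32, 256)  # stand-in for the backbone's global features
print(head(global_feats).shape)  # torch.Size([32, 1])
```

If that assumption holds, the head would be passed as PointNet(head=MyRegressionHead(256, 1), n_global_feats=256). The value of n_global_feats has to match between the backbone and the head's first linear layer.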
See also
PointNetBackbone: for a description of local_feats and n_global_feats.
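In the original [PointNet] segmentation network, each point's local feature vector is concatenated with the global feature vector. Every point then carries both its own features and a summary of the whole cloud. The sketch below shows this concatenation with shapes that match the examples below. The 64 local channels come from the original paper and are an assumption here, not a description of PointNetBackbone.

```python
import torch

B, N = 32, 300
n_local, n_global_feats = 64, 512  # 64 follows the paper; an assumption here

local_feats = torch.randn(B, n_local, N)       # one vector per point
global_feats = torch.randn(B, n_global_feats)  # one vector per cloud

# Repeat the global vector for every point: (B, G) -> (B, G, N).
# expand() returns a view, so no memory is copied.
global_per_point = global_feats.unsqueeze(-1).expand(-1, -1, N)

# Concatenate along the channel dimension: (B, n_local + G, N).
seg_input = torch.cat([local_feats, global_per_point], dim=1)
print(seg_input.shape)  # torch.Size([32, 576, 300])
```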
Examples
>>> import torch
>>> from aidsorb.models import PointNet
>>> from aidsorb.modules import PointNetClsHead, PointNetSegHead
>>> cls_head = PointNetClsHead(n_outputs=2)
>>> seg_head = PointNetSegHead(n_outputs=10)
>>> x = torch.randn(32, 4, 300)

>>> cls_net = PointNet(head=cls_head, n_global_feats=256)
>>> cls_net(x).shape
torch.Size([32, 2])
>>> cls_net.backbone(x)[1].shape  # Critical indices.
torch.Size([32, 256])

>>> seg_net = PointNet(head=seg_head, n_global_feats=512, local_feats=True)
>>> seg_net(x).shape
torch.Size([32, 300, 10])
>>> seg_net.backbone(x)[1].shape  # Critical indices.
torch.Size([32, 512])
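The critical indices returned by backbone(x)[1] in the examples can be used to look up which input points determined the global features (the paper's "critical point set"). The sketch makes three assumptions:

- The indices have shape (B, n_global_feats) and index the N points, as the examples suggest.
- The first three channels of x are the coordinates.
- Random indices stand in for a real backbone output.

```python
import torch

B, C, N, n_global_feats = 32, 4, 300, 256
x = torch.randn(B, C, N)
# Stand-in for backbone(x)[1]: one point index per global feature.
critical_idx = torch.randint(0, N, (B, n_global_feats))

# Gather the xyz coordinates (assumed to be channels 0..2) of each critical point.
coords = x[:, :3, :]                               # (B, 3, N)
idx = critical_idx.unsqueeze(1).expand(-1, 3, -1)  # (B, 3, n_global_feats)
critical_xyz = torch.gather(coords, dim=2, index=idx)
print(critical_xyz.shape)  # torch.Size([32, 3, 256])

# Several features may point at the same point, so the critical set can be smaller.
n_unique = [critical_idx[b].unique().numel() for b in range(B)]
print(max(n_unique) <= min(N, n_global_feats))  # True
```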