# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas
# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
r"""
This module provides deep learning architectures for point cloud processing.
.. note::
Currently, only :class:`PointNet` is implemented, a lightweight version of
the original architecture where the :class:`~aidsorb.modules.TNet`'s for
input and feature transforms have been removed.
.. warning::
It is recommended to **use batched inputs in all cases**. For example, even
if a single ``pcd`` of shape ``(3+C, N)`` is to be processed with
:class:`PointNet`, **reshape it to** ``(1, 3+C, N)``. One way to you can do it
is the following: ``pcd = pcd.unsqueeze(0)``.
.. todo::
Add more architectures for point cloud processing.
References
----------
.. [PointNet] R. Q. Charles, H. Su, M. Kaichun and L. J. Guibas, "PointNet: Deep
Learning on Point Sets for 3D Classification and Segmentation," 2017 IEEE
Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI,
USA, 2017, pp. 77-85, doi: 10.1109/CVPR.2017.16.
"""
import torch
from . modules import PointNetBackbone
[docs]
class PointNet(torch.nn.Module):
r"""
Vanilla version from the [PointNet]_ paper where :class:`TNet`'s have been
removed.
``PointNet`` takes as input a point cloud and produces one or more outputs.
*The type of the task is determined by* ``head``.
Currently implemented heads include:
1. :class:`.PointNetClsHead`: classification and regression
2. :class:`.PointNetSegHead`: segmentation
The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is
the batch size, ``C`` is the number of input channels and ``N`` is the
number of points in each point cloud.
.. tip::
You can define a ``custom_head`` head as a :class:`torch.nn.Module` and
pass it to ``head``.
If ``local_features=False``, the input to ``custom_head`` must have the
same shape as in :meth:`.PointNetClsHead.forward`.
Otherwise, the input to ``custom_head`` must have the same shape as in
:meth:`.PointNetSegHead.forward`.
Parameters
----------
head : :class:`torch.nn.Module`
local_feats : bool, default=False
n_global_feats : int, default=1024
See Also
--------
:class:`~aidsorb.modules.PointNetBackbone` :
For a description of ``local_feats`` and ``n_global_feats``.
Examples
--------
>>> from aidsorb.modules import PointNetClsHead, PointNetSegHead
>>> cls_head = PointNetClsHead(n_outputs=2)
>>> seg_head = PointNetSegHead(n_outputs=10)
>>> x = torch.randn(32, 4, 300)
>>> cls_net = PointNet(head=cls_head, n_global_feats=256)
>>> cls_net(x).shape
torch.Size([32, 2])
>>> cls_net.backbone(x)[1].shape # Critical indices.
torch.Size([32, 256])
>>> seg_net = PointNet(head=seg_head, n_global_feats=512, local_feats=True)
>>> seg_net(x).shape
torch.Size([32, 300, 10])
>>> seg_net.backbone(x)[1].shape # Critical indices.
torch.Size([32, 512])
"""
def __init__(self, head, local_feats=False, n_global_feats=1024):
super().__init__()
self.backbone = PointNetBackbone(local_feats, n_global_feats)
self.head = head
[docs]
def forward(self, x):
r"""
Run the forward pass.
Parameters
----------
x : tensor of shape (B, C, N)
Returns
-------
out : tensor
The output of ``head``.
"""
feats, _ = self.backbone(x) # Ignore critical indices.
out = self.head(feats)
return out