# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas
# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
r"""
:class:`torch.nn.Module` subclasses for point cloud processing.
.. note::
:class:`PointNetBackbone`, :class:`PointNetClsHead` and
:class:`PointNetSegHead` have their initial layers *lazy initialized*, so
you don't need to specify the input dimensionality.
.. warning::
It is recommended to **use batched inputs in all cases**. For example, even
if a single ``pcd`` of shape ``(3+C, N)`` is to be passed to
:class:`PointNetBackbone`, **reshape it to** ``(1, 3+C, N)``. One way you
can do it is the following: ``pcd = pcd.unsqueeze(0)``.
.. todo::
Add more modules for point cloud processing.
References
----------
.. [PointNet] Charles, R. Q., Su, H., Kaichun, M., & Guibas, L. J. (2017).
PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. In
2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 77–85).
IEEE. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
https://doi.org/10.1109/cvpr.2017.16
"""
from typing import Any
import torch
from torch import Tensor
from torch import nn
from torch.nn import Module, Sequential
from .._torch_utils import get_activation
[docs]
def conv1d_block(
        in_channels: int | None,
        out_channels: int,
        config_activation: dict[str, str | dict] | None = None,
        **kwargs
        ) -> Sequential:
    r"""
    Return a 1D convolutional block.

    The block has the following form::

        block = nn.Sequential(
            conv_layer,
            nn.BatchNorm1d(out_channels),
            activation_fn
        )

    Parameters
    ----------
    in_channels : int or None
        If :obj:`None`, the ``conv_layer`` is lazy initialized, i.e. a
        :class:`~torch.nn.LazyConv1d` is used and the number of input channels
        is inferred from the first input it receives.
    out_channels : int
    config_activation : dict, default=None
        Dictionary for configuring activation function. If :obj:`None`, the
        :class:`~torch.nn.modules.activation.ReLU` activation is used.

        * ``'name'`` activation's class name :class:`str`
        * ``'hparams'`` activation's hyperparameters :class:`dict`
    **kwargs
        Valid keyword arguments for :class:`~torch.nn.Conv1d`.

    Returns
    -------
    block : torch.nn.Sequential

    Examples
    --------
    >>> inp, out = 4, 128
    >>> x = torch.randn(32, 4, 100)  # Shape (B, in_channels, N).
    >>> config_afn = {'name': 'LeakyReLU', 'hparams': {'negative_slope': 0.5}}

    >>> # Default activation function (ReLU).
    >>> block = conv1d_block(inp, out, kernel_size=1)
    >>> block(x).shape
    torch.Size([32, 128, 100])
    >>> block[2]
    ReLU()

    >>> # Custom activation function.
    >>> block = conv1d_block(inp, out, config_afn, kernel_size=1)
    >>> block(x).shape
    torch.Size([32, 128, 100])
    >>> block[2]
    LeakyReLU(negative_slope=0.5)

    >>> # Lazy initialized.
    >>> block = conv1d_block(None, out, kernel_size=1)
    >>> block(x).shape
    torch.Size([32, 128, 100])
    """
    # ``nn.Conv1d`` needs the number of input channels up front, while
    # ``nn.LazyConv1d`` infers it on the first forward pass.
    if in_channels is not None:
        conv_layer = nn.Conv1d(in_channels, out_channels, **kwargs)
    else:
        conv_layer = nn.LazyConv1d(out_channels, **kwargs)

    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm1d(out_channels),
        get_activation(config_activation)
    )

    return block
[docs]
def dense_block(
        in_features: int | None,
        out_features: int,
        config_activation: dict[str, str | dict] | None = None,
        **kwargs
        ) -> Sequential:
    r"""
    Return a dense block.

    The block has the following form::

        block = nn.Sequential(
            linear_layer,
            nn.BatchNorm1d(out_features),
            activation_fn,
        )

    Parameters
    ----------
    in_features : int or None
        If :obj:`None`, the ``linear_layer`` is lazy initialized, i.e. a
        :class:`~torch.nn.LazyLinear` is used and the number of input features
        is inferred from the first input it receives.
    out_features : int
    config_activation : dict, default=None
        Dictionary for configuring activation function. If :obj:`None`, the
        :class:`~torch.nn.modules.activation.ReLU` activation is used.

        * ``'name'`` activation's class name :class:`str`
        * ``'hparams'`` activation's hyperparameters :class:`dict`
    **kwargs
        Valid keyword arguments for :class:`~torch.nn.Linear`.

    Returns
    -------
    block : torch.nn.Sequential

    Examples
    --------
    >>> inp, out = 3, 10
    >>> x = torch.randn(64, inp)  # Shape (B, in_features).
    >>> config_afn = {'name': 'SELU', 'hparams': {}}

    >>> # Default activation function (ReLU).
    >>> block = dense_block(inp, out)
    >>> block(x).shape
    torch.Size([64, 10])
    >>> block[2]
    ReLU()

    >>> # Custom activation function.
    >>> block = dense_block(inp, out, config_afn)
    >>> block(x).shape
    torch.Size([64, 10])
    >>> block[2]
    SELU()

    >>> # Lazy initialized.
    >>> block = dense_block(None, 16)
    >>> block(x).shape
    torch.Size([64, 16])
    """
    # ``nn.Linear`` needs the number of input features up front, while
    # ``nn.LazyLinear`` infers it on the first forward pass.
    if in_features is not None:
        linear_layer = nn.Linear(in_features, out_features, **kwargs)
    else:
        linear_layer = nn.LazyLinear(out_features, **kwargs)

    block = nn.Sequential(
        linear_layer,
        nn.BatchNorm1d(out_features),
        get_activation(config_activation)
    )

    return block
[docs]
class TNet(nn.Module):
r"""
Spatial transformer network (STN) from the [PointNet]_ paper for performing
the input and feature transform.
``T-Net`` takes as input a (possibly embedded) point cloud of shape ``(dim, N)``
and regresses a ``(dim, dim)`` matrix. Each point in the point cloud has
shape ``(dim,)``.
The input must be *batched*, i.e. have shape of ``(B, dim, N)``, where ``B`` is
the batch size and ``N`` is the number of points in each point cloud.
Parameters
----------
embed_dim : int
Embedding dimension.
Examples
--------
>>> tnet = TNet(embed_dim=64)
>>> x = torch.randn((128, 64, 42)) # Shape (B, embed_dim, N).
>>> tnet(x).shape
torch.Size([128, 64, 64])
"""
def __init__(self, embed_dim: int) -> None:
super().__init__()
self.embed_dim = embed_dim
self.conv_blocks = nn.Sequential(
conv1d_block(embed_dim, 64, kernel_size=1, bias=False),
conv1d_block(64, 128, kernel_size=1, bias=False),
conv1d_block(128, 1024, kernel_size=1, bias=False),
)
self.dense_blocks = nn.Sequential(
dense_block(1024, 512, bias=False),
dense_block(512, 256, bias=False),
nn.Linear(256, embed_dim * embed_dim),
)
[docs]
def forward(self, x: Tensor) -> Tensor:
r"""
Return the regressed matrices.
Parameters
----------
x : tensor of shape (B, embed_dim, N)
Returns
-------
out : tensor of shape (B, embed_dim, embed_dim)
Regressed matrices.
"""
# Input has shape (B, self.embed_dim, N).
bs = x.shape[0]
x = self.conv_blocks(x)
x, _ = torch.max(x, 2, keepdim=False) # Ignore indices.
x = self.dense_blocks(x)
# Initialize the identity matrix.
identity = torch.eye(self.embed_dim, device=x.device, requires_grad=x.requires_grad).repeat(bs, 1, 1)
# Output has shape (B, self.embed_dim, self.embed_dim).
return x.view(-1, self.embed_dim, self.embed_dim) + identity
[docs]
class PointNetBackbone(nn.Module):
r"""
Backbone of the vanilla version from the [PointNet]_ paper, where
:class:`TNet`'s have been removed.
This module extracts features which can then be passed to a task head for
predictions. This module also returns the *critical indices*.
The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is
the batch size, ``C`` is the number of input channels and ``N`` is the
number of points in each point cloud.
Parameters
----------
n_global_feats : int, default=1024
Number of global features.
local_feats : bool, default=False
If :obj:`True`, the returned features are a concatenation of local features
and global features. Otherwise, the global features are returned.
Examples
--------
>>> feat = PointNetBackbone(2048)
>>> x = torch.randn(32, 4, 200)
>>> features, indices = feat(x, return_indices=True)
>>> features.shape
torch.Size([32, 2048])
>>> indices.shape
torch.Size([32, 2048])
>>> feat = PointNetBackbone(1024, True)
>>> x = torch.randn(16, 4, 100)
>>> features, indices = feat(x, return_indices=True)
>>> features.shape
torch.Size([16, 1088, 100])
>>> indices.shape
torch.Size([16, 1024])
>>> feat = PointNetBackbone(512)
>>> x = torch.randn(8, 3, 50)
>>> feat(x).shape # Only features, no critical indices.
torch.Size([8, 512])
"""
def __init__(
self,
n_global_feats: int = 1024,
local_feats: bool = False
) -> None:
super().__init__()
self.local_feats = local_feats
# First shared MLP.
self.shared_mlp_1 = nn.Sequential(
conv1d_block(None, 64, kernel_size=1, bias=False),
conv1d_block(64, 64, kernel_size=1, bias=False),
)
# Second shared MLP.
self.shared_mlp_2 = nn.Sequential(
conv1d_block(64, 64, kernel_size=1, bias=False),
conv1d_block(64, 128, kernel_size=1, bias=False),
conv1d_block(128, n_global_feats, kernel_size=1, bias=False),
)
[docs]
def forward(
self,
x: Tensor,
return_indices: bool = False
) -> Tensor | tuple[Tensor, Tensor]:
r"""
Return the *features* and optionally *critical indices*.
The type of the features is determined by ``local_feats``.
Parameters
----------
x : tensor of shape (B, C, N)
return_indices : bool, default=False
Whether to return critical indices.
Returns
-------
out : tensor or tuple of tensors
If ``return_indices=False`` the output are the features, otherwise
tuple of the form ``(features, critical_indices)``.
"""
n_points = x.shape[2]
x = self.shared_mlp_1(x)
point_feats = x
x = self.shared_mlp_2(x)
# Shape (B, n_global_feats).
global_feats, critical_indices = torch.max(x, 2, keepdim=False)
out = global_feats
if self.local_feats:
# Shape (B, n_global_feats + 64, N)
out = torch.cat((
point_feats,
global_feats.unsqueeze(-1).repeat(1, 1, n_points)
), dim=1)
if return_indices:
return out, critical_indices
return out
[docs]
class PointNetClsHead(nn.Module):
    r"""
    Classification head from the [PointNet]_ paper.

    .. tip::
        This head can be used for classification or regression tasks.

    Parameters
    ----------
    n_outputs : int, default=1
        Size of the last dimension of the output.
    dropout_rate : float, default=0
        Probability passed to :class:`~torch.nn.Dropout`, which is applied right
        before the output layer.

    Examples
    --------
    >>> head = PointNetClsHead(n_outputs=4)
    >>> x = torch.randn(64, 13)
    >>> head(x).shape
    torch.Size([64, 4])
    """
    def __init__(self, n_outputs: int = 1, dropout_rate: float = 0) -> None:
        super().__init__()

        # The first block is lazy initialized: its input size is inferred from
        # the first batch of features it receives.
        layers = [
            dense_block(None, 512, bias=False),
            dense_block(512, 256, bias=False),
            nn.Dropout(dropout_rate),
            nn.Linear(256, n_outputs),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Map features to predictions.

        Parameters
        ----------
        x : tensor of shape (B, C)

        Returns
        -------
        tensor of shape (B, n_outputs)
        """
        out = self.mlp(x)
        return out
[docs]
class PointNetSegHead(nn.Module):
    r"""
    Segmentation head from the [PointNet]_ paper.

    .. tip::
        This head can be used for segmentation tasks.

    Parameters
    ----------
    n_outputs : int, default=1
        Number of outputs predicted for every point.

    Examples
    --------
    >>> head = PointNetSegHead(n_outputs=2)
    >>> x = torch.randn(32, 1088, 400)
    >>> head(x).shape
    torch.Size([32, 400, 2])
    """
    def __init__(self, n_outputs: int = 1) -> None:
        super().__init__()

        # Per-point MLP as 1x1 convolutions; the first one is lazy initialized,
        # so the number of input channels is inferred on the first call.
        layers = [
            conv1d_block(None, 512, kernel_size=1, bias=False),
            conv1d_block(512, 256, kernel_size=1, bias=False),
            conv1d_block(256, 128, kernel_size=1, bias=False),
            nn.Conv1d(128, n_outputs, kernel_size=1),
        ]
        self.shared_mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Predict outputs for every point.

        Parameters
        ----------
        x : tensor of shape (B, C, N)

        Returns
        -------
        tensor of shape (B, N, n_outputs)
        """
        # (B, n_outputs, N) -> (B, N, n_outputs).
        return self.shared_mlp(x).transpose(1, 2)
[docs]
class PointNet(torch.nn.Module):
    r"""
    Vanilla version from the [PointNet]_ paper where :class:`TNet`'s have been
    removed.

    :class:`PointNet` takes as input a point cloud and produces one or more
    outputs. *The type of the task is determined by* ``head``.

    Currently implemented heads include:

    1. :class:`.PointNetClsHead`: classification and regression
    2. :class:`.PointNetSegHead`: segmentation

    The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is
    the batch size, ``C`` is the number of input channels and ``N`` is the
    number of points in each point cloud.

    .. tip::
        You can define a ``custom_head`` head as a :class:`torch.nn.Module` and
        pass it to ``head``.

        If ``local_feats=False``, the input to ``custom_head`` must have the
        same shape as in :meth:`PointNetClsHead.forward`.

        Otherwise, the input to ``custom_head`` must have the same shape as in
        :meth:`PointNetSegHead.forward`.

    Parameters
    ----------
    head : torch.nn.Module
    n_global_feats : int, default=1024
    local_feats : bool, default=False

    See Also
    --------
    :class:`.PointNetBackbone` :
        For a description of ``local_feats`` and ``n_global_feats``.

    Examples
    --------
    >>> cls_head = PointNetClsHead(n_outputs=2)
    >>> seg_head = PointNetSegHead(n_outputs=10)
    >>> x = torch.randn(32, 4, 300)

    >>> cls_net = PointNet(cls_head, 256)
    >>> cls_net(x).shape
    torch.Size([32, 2])
    >>> cls_net.backbone(x).shape  # Only features.
    torch.Size([32, 256])

    >>> seg_net = PointNet(head=seg_head, n_global_feats=512, local_feats=True)
    >>> seg_net(x).shape
    torch.Size([32, 300, 10])
    >>> seg_net.backbone(x, True)[1].shape  # Features and critical indices.
    torch.Size([32, 512])
    """
    def __init__(
            self,
            head: Module,
            n_global_feats: int = 1024,
            local_feats: bool = False
            ) -> None:
        super().__init__()

        # Registration order (backbone first, then head) is kept so that the
        # ``state_dict`` layout stays the same.
        self.backbone = PointNetBackbone(
            n_global_feats=n_global_feats, local_feats=local_feats
        )
        self.head = head

    def forward(self, x: Tensor) -> Any:
        r"""
        Extract features with the backbone and feed them to the head.

        Parameters
        ----------
        x : tensor of shape (B, C, N)

        Returns
        -------
        Output of ``head``.
        """
        features = self.backbone(x)
        return self.head(features)