Source code for aidsorb.modules.points

# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas

# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

r"""
:class:`torch.nn.Module`'s for point cloud processing.

.. note::
    :class:`PointNetBackbone`, :class:`PointNetClsHead` and
    :class:`PointNetSegHead` have their initial layers *lazy initialized*, so
    you don't need to specify the input dimensionality.

.. warning::
    It is recommended to **use batched inputs in all cases**. For example, even
    if a single ``pcd`` of shape ``(3+C, N)`` is to be passed to
    :class:`PointNetBackbone`, **reshape it to** ``(1, 3+C, N)``. One way you
    can do it is the following: ``pcd = pcd.unsqueeze(0)``.

.. todo::
    Add more modules for point cloud processing.

References
----------

.. [PointNet] Qi, C. R., Su, H., Mo, K., & Guibas, L. J. (2017).
              PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. In
              2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 77–85).
              IEEE. https://doi.org/10.1109/cvpr.2017.16
"""

from typing import Any

import torch
from torch import Tensor
from torch import nn
from torch.nn import Module, Sequential

from .._torch_utils import get_activation


[docs] def conv1d_block( in_channels: int, out_channels: int, config_activation: dict[str, str | dict] | None = None, **kwargs ) -> Sequential: r""" Return a 1D convolutional block. The block has the following form:: block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), activation_fn ) Parameters ---------- in_channels : int or None If :obj:`None`, the ``conv_layer`` is lazy initialized. out_channels : int config_activation : dict, default=None Dictionary for configuring activation function. If :obj:`None`, the :class:`~torch.nn.modules.activation.ReLU` activation is used. * ``'name'`` activation's class name :class:`str` * ``'hparams'`` activation's hyperparameters :class:`dict` **kwargs Valid keyword arguments for :class:`~torch.nn.Conv1d`. Returns ------- block : torch.nn.Sequential Examples -------- >>> inp, out = 4, 128 >>> x = torch.randn(32, 4, 100) # Shape (B, in_channels, N). >>> config_afn = {'name': 'LeakyReLU', 'hparams': {'negative_slope': 0.5}} >>> # Default activation function (ReLU). >>> block = conv1d_block(inp, out, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) >>> block[2] ReLU() >>> # Custom activation function. >>> block = conv1d_block(inp, out, config_afn, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) >>> block[2] LeakyReLU(negative_slope=0.5) >>> # Lazy initialized. >>> block = conv1d_block(None, out, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) """ if in_channels is not None: conv_layer = nn.Conv1d(in_channels, out_channels, **kwargs) else: conv_layer = nn.LazyConv1d(out_channels, **kwargs) block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), get_activation(config_activation) ) return block
[docs] def dense_block( in_features: int, out_features: int, config_activation: dict[str, str | dict] | None = None, **kwargs ) -> Sequential: r""" Return a dense block. The block has the following form:: block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), activation_fn, ) Parameters ---------- in_features : int or None If :obj:`None`, the ``linear_layer`` is lazy initialized. out_features : int config_activation : dict, default=None Dictionary for configuring activation function. If :obj:`None`, the :class:`~torch.nn.modules.activation.ReLU` activation is used. * ``'name'`` activation's class name :class:`str` * ``'hparams'`` activation's hyperparameters :class:`dict` **kwargs Valid keyword arguments for :class:`~torch.nn.Linear`. Returns ------- block : torch.nn.Sequential Examples -------- >>> inp, out = 3, 10 >>> x = torch.randn(64, inp) # Shape (B, in_features). >>> config_afn = {'name': 'SELU', 'hparams': {}} >>> # Default activation function (ReLU). >>> block = dense_block(inp, out) >>> block(x).shape torch.Size([64, 10]) >>> block[2] ReLU() >>> # Custom activation function. >>> block = dense_block(inp, out, config_afn) >>> block(x).shape torch.Size([64, 10]) >>> block[2] SELU() >>> # Lazy initialized. >>> block = dense_block(None, 16) >>> block(x).shape torch.Size([64, 16]) """ if in_features is not None: linear_layer = nn.Linear(in_features, out_features, **kwargs) else: linear_layer = nn.LazyLinear(out_features, **kwargs) block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), get_activation(config_activation) ) return block
class TNet(nn.Module):
    r"""
    Spatial transformer network (STN) from the [PointNet]_ paper for performing
    the input and feature transform.

    ``T-Net`` takes as input a (possibly embedded) point cloud of shape ``(dim,
    N)`` and regresses a ``(dim, dim)`` matrix. Each point in the point cloud
    has shape ``(dim,)``.

    The input must be *batched*, i.e. have shape of ``(B, dim, N)``, where ``B``
    is the batch size and ``N`` is the number of points in each point cloud.

    Parameters
    ----------
    embed_dim : int
        Embedding dimension.

    Examples
    --------
    >>> tnet = TNet(embed_dim=64)
    >>> x = torch.randn((128, 64, 42))  # Shape (B, embed_dim, N).
    >>> tnet(x).shape
    torch.Size([128, 64, 64])
    """
    def __init__(self, embed_dim: int) -> None:
        super().__init__()

        self.embed_dim = embed_dim

        # Point-wise feature extraction (kernel_size=1 acts on each point
        # independently). Bias is disabled since each convolution is followed
        # by batch normalization.
        self.conv_blocks = nn.Sequential(
            conv1d_block(embed_dim, 64, kernel_size=1, bias=False),
            conv1d_block(64, 128, kernel_size=1, bias=False),
            conv1d_block(128, 1024, kernel_size=1, bias=False),
        )

        # Regress the flattened (embed_dim * embed_dim) matrix from the pooled
        # features.
        self.dense_blocks = nn.Sequential(
            dense_block(1024, 512, bias=False),
            dense_block(512, 256, bias=False),
            nn.Linear(256, embed_dim * embed_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Return the regressed matrices.

        Parameters
        ----------
        x : tensor of shape (B, embed_dim, N)

        Returns
        -------
        out : tensor of shape (B, embed_dim, embed_dim)
            Regressed matrices, i.e. the network output reshaped to square
            matrices with the identity matrix added to each of them.
        """
        # Input has shape (B, self.embed_dim, N).
        x = self.conv_blocks(x)
        x, _ = torch.max(x, 2, keepdim=False)  # Pool over points, ignore indices.
        x = self.dense_blocks(x)

        # The identity is a constant: it needs no gradient tracking, it must
        # match the dtype/device of the activations, and broadcasting takes
        # care of the batch dimension (no explicit copy per sample).
        identity = torch.eye(self.embed_dim, dtype=x.dtype, device=x.device)

        # Output has shape (B, self.embed_dim, self.embed_dim).
        return x.view(-1, self.embed_dim, self.embed_dim) + identity
[docs] class PointNetBackbone(nn.Module): r""" Backbone of the vanilla version from the [PointNet]_ paper, where :class:`TNet`'s have been removed. This module extracts features which can then be passed to a task head for predictions. This module also returns the *critical indices*. The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is the batch size, ``C`` is the number of input channels and ``N`` is the number of points in each point cloud. Parameters ---------- n_global_feats : int, default=1024 Number of global features. local_feats : bool, default=False If :obj:`True`, the returned features are a concatenation of local features and global features. Otherwise, the global features are returned. Examples -------- >>> feat = PointNetBackbone(2048) >>> x = torch.randn(32, 4, 200) >>> features, indices = feat(x, return_indices=True) >>> features.shape torch.Size([32, 2048]) >>> indices.shape torch.Size([32, 2048]) >>> feat = PointNetBackbone(1024, True) >>> x = torch.randn(16, 4, 100) >>> features, indices = feat(x, return_indices=True) >>> features.shape torch.Size([16, 1088, 100]) >>> indices.shape torch.Size([16, 1024]) >>> feat = PointNetBackbone(512) >>> x = torch.randn(8, 3, 50) >>> feat(x).shape # Only features, no critical indices. torch.Size([8, 512]) """ def __init__( self, n_global_feats: int = 1024, local_feats: bool = False ) -> None: super().__init__() self.local_feats = local_feats # First shared MLP. self.shared_mlp_1 = nn.Sequential( conv1d_block(None, 64, kernel_size=1, bias=False), conv1d_block(64, 64, kernel_size=1, bias=False), ) # Second shared MLP. self.shared_mlp_2 = nn.Sequential( conv1d_block(64, 64, kernel_size=1, bias=False), conv1d_block(64, 128, kernel_size=1, bias=False), conv1d_block(128, n_global_feats, kernel_size=1, bias=False), )
[docs] def forward( self, x: Tensor, return_indices: bool = False ) -> Tensor | tuple[Tensor, Tensor]: r""" Return the *features* and optionally *critical indices*. The type of the features is determined by ``local_feats``. Parameters ---------- x : tensor of shape (B, C, N) return_indices : bool, default=False Whether to return critical indices. Returns ------- out : tensor or tuple of tensors If ``return_indices=False`` the output are the features, otherwise tuple of the form ``(features, critical_indices)``. """ n_points = x.shape[2] x = self.shared_mlp_1(x) point_feats = x x = self.shared_mlp_2(x) # Shape (B, n_global_feats). global_feats, critical_indices = torch.max(x, 2, keepdim=False) out = global_feats if self.local_feats: # Shape (B, n_global_feats + 64, N) out = torch.cat(( point_feats, global_feats.unsqueeze(-1).repeat(1, 1, n_points) ), dim=1) if return_indices: return out, critical_indices return out
class PointNetClsHead(nn.Module):
    r"""
    Classification head from the [PointNet]_ paper.

    The head is an MLP of two dense blocks (512 and 256 units) followed by
    dropout and a final linear layer. The first dense block is lazy
    initialized, so the number of input features is not specified.

    .. tip::
        This head can be used for classification or regression tasks.

    Parameters
    ----------
    n_outputs : int, default=1
        Number of outputs of the final linear layer.
    dropout_rate : float, default=0
        Probability passed to :class:`~torch.nn.Dropout`.

    Examples
    --------
    >>> head = PointNetClsHead(n_outputs=4)
    >>> x = torch.randn(64, 13)
    >>> head(x).shape
    torch.Size([64, 4])
    """
    def __init__(self, n_outputs: int = 1, dropout_rate: float = 0) -> None:
        super().__init__()

        layers = [
            dense_block(None, 512, bias=False),  # Lazy input dimensionality.
            dense_block(512, 256, bias=False),
            nn.Dropout(dropout_rate),
            nn.Linear(256, n_outputs),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Pass the features through the MLP.

        Parameters
        ----------
        x : tensor of shape (B, C)

        Returns
        -------
        tensor of shape (B, n_outputs)
        """
        predictions = self.mlp(x)

        return predictions
class PointNetSegHead(nn.Module):
    r"""
    Segmentation head from the [PointNet]_ paper.

    The head is a shared (point-wise) MLP of three convolutional blocks (512,
    256 and 128 channels) followed by a final point-wise convolution. The first
    block is lazy initialized, so the number of input channels is not
    specified.

    .. tip::
        This head can be used for segmentation tasks.

    Parameters
    ----------
    n_outputs : int, default=1
        Number of outputs per point.

    Examples
    --------
    >>> head = PointNetSegHead(n_outputs=2)
    >>> x = torch.randn(32, 1088, 400)
    >>> head(x).shape
    torch.Size([32, 400, 2])
    """
    def __init__(self, n_outputs: int = 1) -> None:
        super().__init__()

        layers = [
            conv1d_block(None, 512, kernel_size=1, bias=False),  # Lazy input.
            conv1d_block(512, 256, kernel_size=1, bias=False),
            conv1d_block(256, 128, kernel_size=1, bias=False),
            nn.Conv1d(128, n_outputs, kernel_size=1),
        ]
        self.shared_mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Pass the per-point features through the shared MLP.

        Parameters
        ----------
        x : tensor of shape (B, C, N)

        Returns
        -------
        tensor of shape (B, N, n_outputs)
        """
        # Shape (B, n_outputs, N).
        per_channel = self.shared_mlp(x)

        # Swap the last two axes, shape (B, N, n_outputs).
        return per_channel.transpose(1, 2)
class PointNet(nn.Module):
    r"""
    Vanilla version from the [PointNet]_ paper where :class:`TNet`'s have been
    removed.

    :class:`PointNet` feeds a point cloud to a :class:`PointNetBackbone` and
    passes the extracted features to ``head``, which produces one or more
    outputs. *The type of the task is determined by* ``head``.

    Currently implemented heads include:

        1. :class:`.PointNetClsHead`: classification and regression
        2. :class:`.PointNetSegHead`: segmentation

    The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is
    the batch size, ``C`` is the number of input channels and ``N`` is the
    number of points in each point cloud.

    .. tip::
        You can define a ``custom_head`` head as a :class:`torch.nn.Module` and
        pass it to ``head``.

        If ``local_feats=False``, the input to ``custom_head`` must have the
        same shape as in :meth:`PointNetClsHead.forward`. Otherwise, the input
        to ``custom_head`` must have the same shape as in
        :meth:`PointNetSegHead.forward`.

    Parameters
    ----------
    head : torch.nn.Module
        Module applied to the features returned by the backbone.
    n_global_feats : int, default=1024
    local_feats : bool, default=False

    See Also
    --------
    :class:`.PointNetBackbone` :
        For a description of ``local_feats`` and ``n_global_feats``.

    Examples
    --------
    >>> cls_head = PointNetClsHead(n_outputs=2)
    >>> seg_head = PointNetSegHead(n_outputs=10)
    >>> x = torch.randn(32, 4, 300)

    >>> cls_net = PointNet(cls_head, 256)
    >>> cls_net(x).shape
    torch.Size([32, 2])
    >>> cls_net.backbone(x).shape  # Only features.
    torch.Size([32, 256])

    >>> seg_net = PointNet(head=seg_head, n_global_feats=512, local_feats=True)
    >>> seg_net(x).shape
    torch.Size([32, 300, 10])
    >>> seg_net.backbone(x, True)[1].shape  # Features and critical indices.
    torch.Size([32, 512])
    """
    def __init__(
            self,
            head: Module,
            n_global_feats: int = 1024,
            local_feats: bool = False
            ) -> None:
        super().__init__()

        # The backbone is registered before the head.
        backbone = PointNetBackbone(
            n_global_feats=n_global_feats,
            local_feats=local_feats,
        )
        self.backbone = backbone
        self.head = head

    def forward(self, x: Tensor) -> Any:
        r"""
        Extract features with the backbone and pass them to ``head``.

        Parameters
        ----------
        x : tensor of shape (B, C, N)

        Returns
        -------
        Output of ``head``.
        """
        # Only the features are requested (no critical indices).
        features = self.backbone(x)

        return self.head(features)