Source code for aidsorb.modules

# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas

# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

r"""
:class:`torch.nn.Module`'s for point cloud processing.

.. note::
    :class:`PointNetBackbone`, :class:`PointNetClsHead` and
    :class:`PointNetSegHead` have their initial layers *lazy initialized*, so
    you don't need to specify the input dimensionality.

.. warning::
    It is recommended to **use batched inputs in all cases**. For example, even
    if a single ``pcd`` of shape ``(3+C, N)`` is to be passed to
    :class:`PointNetBackbone`, **reshape it to** ``(1, 3+C, N)``. One way to you
    can do it is the following: ``pcd = pcd.unsqueeze(0)``.

.. todo::
    Add more modules for point cloud processing.

References
----------

.. [PointNet] R. Q. Charles, H. Su, M. Kaichun and L. J. Guibas, "PointNet: Deep
              Learning on Point Sets for 3D Classification and Segmentation," 2017 IEEE
              Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI,
              USA, 2017, pp. 77-85, doi: 10.1109/CVPR.2017.16.
"""

from typing import Any

import torch
from torch import Tensor
from torch import nn
from torch.nn import Module, Sequential

from ._torch_utils import get_activation


[docs] def conv1d_block( in_channels: int, out_channels: int, config_activation: dict[str, str | dict] | None = None, **kwargs ) -> Sequential: r""" Return a 1D convolutional block. The block has the following form:: block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), activation_fn ) Parameters ---------- in_channels : int or None If :obj:`None`, the ``conv_layer`` is lazy initialized. out_channels : int config_activation : dict, default=None Dictionary for configuring activation function. If :obj:`None`, the :class:`~torch.nn.modules.activation.ReLU` activation is used. * ``'name'`` activation's class name :class:`str` * ``'hparams'`` activation's hyperparameters :class:`dict` **kwargs Valid keyword arguments for :class:`~torch.nn.Conv1d`. Returns ------- block : torch.nn.Sequential Examples -------- >>> inp, out = 4, 128 >>> x = torch.randn(32, 4, 100) # Shape (B, in_channels, N). >>> config_afn = {'name': 'LeakyReLU', 'hparams': {'negative_slope': 0.5}} >>> # Default activation function (ReLU). >>> block = conv1d_block(inp, out, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) >>> block[2] ReLU() >>> # Custom activation function. >>> block = conv1d_block(inp, out, config_afn, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) >>> block[2] LeakyReLU(negative_slope=0.5) >>> # Lazy initialized. >>> block = conv1d_block(None, out, kernel_size=1) >>> block(x).shape torch.Size([32, 128, 100]) """ if in_channels is not None: conv_layer = nn.Conv1d(in_channels, out_channels, **kwargs) else: conv_layer = nn.LazyConv1d(out_channels, **kwargs) block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), get_activation(config_activation) ) return block
[docs] def dense_block( in_features: int, out_features: int, config_activation: dict[str, str | dict] | None = None, **kwargs ) -> Sequential: r""" Return a dense block. The block has the following form:: block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), activation_fn, ) Parameters ---------- in_features : int or None If :obj:`None`, the ``linear_layer`` is lazy initialized. out_features : int config_activation : dict, default=None Dictionary for configuring activation function. If :obj:`None`, the :class:`~torch.nn.modules.activation.ReLU` activation is used. * ``'name'`` activation's class name :class:`str` * ``'hparams'`` activation's hyperparameters :class:`dict` **kwargs Valid keyword arguments for :class:`~torch.nn.Linear`. Returns ------- block : torch.nn.Sequential Examples -------- >>> inp, out = 3, 10 >>> x = torch.randn(64, inp) # Shape (B, in_features). >>> config_afn = {'name': 'SELU', 'hparams': {}} >>> # Default activation function (ReLU). >>> block = dense_block(inp, out) >>> block(x).shape torch.Size([64, 10]) >>> block[2] ReLU() >>> # Custom activation function. >>> block = dense_block(inp, out, config_afn) >>> block(x).shape torch.Size([64, 10]) >>> block[2] SELU() >>> # Lazy initialized. >>> block = dense_block(None, 16) >>> block(x).shape torch.Size([64, 16]) """ if in_features is not None: linear_layer = nn.Linear(in_features, out_features, **kwargs) else: linear_layer = nn.LazyLinear(out_features, **kwargs) block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), get_activation(config_activation) ) return block
[docs] class TNet(nn.Module): r""" Spatial transformer network (STN) from the [PointNet]_ paper for performing the input and feature transform. ``T-Net`` takes as input a (possibly embedded) point cloud of shape ``(dim, N)`` and regresses a ``(dim, dim)`` matrix. Each point in the point cloud has shape ``(dim,)``. The input must be *batched*, i.e. have shape of ``(B, dim, N)``, where ``B`` is the batch size and ``N`` is the number of points in each point cloud. Parameters ---------- embed_dim : int Embedding dimension. Examples -------- >>> tnet = TNet(embed_dim=64) >>> x = torch.randn((128, 64, 42)) # Shape (B, embed_dim, N). >>> tnet(x).shape torch.Size([128, 64, 64]) """ def __init__(self, embed_dim: int) -> None: super().__init__() self.embed_dim = embed_dim self.conv_blocks = nn.Sequential( conv1d_block(embed_dim, 64, kernel_size=1, bias=False), conv1d_block(64, 128, kernel_size=1, bias=False), conv1d_block(128, 1024, kernel_size=1, bias=False), ) self.dense_blocks = nn.Sequential( dense_block(1024, 512, bias=False), dense_block(512, 256, bias=False), nn.Linear(256, embed_dim * embed_dim), )
[docs] def forward(self, x: Tensor) -> Tensor: r""" Return the regressed matrices. Parameters ---------- x : tensor of shape (B, embed_dim, N) Returns ------- out : tensor of shape (B, embed_dim, embed_dim) Regressed matrices. """ # Input has shape (B, self.embed_dim, N). bs = x.shape[0] x = self.conv_blocks(x) x, _ = torch.max(x, 2, keepdim=False) # Ignore indices. x = self.dense_blocks(x) # Initialize the identity matrix. identity = torch.eye(self.embed_dim, device=x.device, requires_grad=x.requires_grad).repeat(bs, 1, 1) # Output has shape (B, self.embed_dim, self.embed_dim). return x.view(-1, self.embed_dim, self.embed_dim) + identity
[docs] class PointNetBackbone(nn.Module): r""" Backbone of the vanilla version from the [PointNet]_ paper, where :class:`TNet`'s have been removed. This module extracts features which can then be passed to a task head for predictions. This module also returns the *critical indices*. The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is the batch size, ``C`` is the number of input channels and ``N`` is the number of points in each point cloud. Parameters ---------- n_global_feats : int, default=1024 Number of global features. local_feats : bool, default=False If :obj:`True`, the returned features are a concatenation of local features and global features. Otherwise, the global features are returned. Examples -------- >>> feat = PointNetBackbone(2048) >>> x = torch.randn(32, 4, 200) >>> features, indices = feat(x, return_indices=True) >>> features.shape torch.Size([32, 2048]) >>> indices.shape torch.Size([32, 2048]) >>> feat = PointNetBackbone(1024, True) >>> x = torch.randn(16, 4, 100) >>> features, indices = feat(x, return_indices=True) >>> features.shape torch.Size([16, 1088, 100]) >>> indices.shape torch.Size([16, 1024]) >>> feat = PointNetBackbone(512) >>> x = torch.randn(8, 3, 50) >>> feat(x).shape # Only features, no critical indices. torch.Size([8, 512]) """ def __init__( self, n_global_feats: int = 1024, local_feats: bool = False ) -> None: super().__init__() self.local_feats = local_feats # First shared MLP. self.shared_mlp_1 = nn.Sequential( conv1d_block(None, 64, kernel_size=1, bias=False), conv1d_block(64, 64, kernel_size=1, bias=False), ) # Second shared MLP. self.shared_mlp_2 = nn.Sequential( conv1d_block(64, 64, kernel_size=1, bias=False), conv1d_block(64, 128, kernel_size=1, bias=False), conv1d_block(128, n_global_feats, kernel_size=1, bias=False), )
[docs] def forward( self, x: Tensor, return_indices: bool = False ) -> Tensor | tuple[Tensor, Tensor]: r""" Return the *features* and optionally *critical indices*. The type of the features is determined by ``local_feats``. Parameters ---------- x : tensor of shape (B, C, N) return_indices : bool, default=False Whether to return critical indices. Returns ------- out : tensor or tuple of tensors If ``return_indices=False`` the output are the features, otherwise tuple of the form ``(features, critical_indices)``. """ n_points = x.shape[2] x = self.shared_mlp_1(x) point_feats = x x = self.shared_mlp_2(x) # Shape (B, n_global_feats). global_feats, critical_indices = torch.max(x, 2, keepdim=False) out = global_feats if self.local_feats: # Shape (B, n_global_feats + 64, N) out = torch.cat(( point_feats, global_feats.unsqueeze(-1).repeat(1, 1, n_points) ), dim=1) if return_indices: return out, critical_indices return out
[docs] class PointNetClsHead(nn.Module): r""" Classification head from the [PointNet]_ paper. .. tip:: This head can be used for classification or regression tasks. Parameters ---------- n_outputs : int, default=1 dropout_rate : float, default=0 Examples -------- >>> head = PointNetClsHead(n_outputs=4) >>> x = torch.randn(64, 13) >>> head(x).shape torch.Size([64, 4]) """ def __init__(self, n_outputs: int = 1, dropout_rate: float = 0) -> None: super().__init__() self.mlp = nn.Sequential( dense_block(None, 512, bias=False), dense_block(512, 256, bias=False), nn.Dropout(dropout_rate), nn.Linear(256, n_outputs), )
[docs] def forward(self, x: Tensor) -> Tensor: r""" Run the forward pass. Parameters ---------- x : tensor of shape (B, C) Returns ------- out : tensor of shape (B, n_outputs) """ return self.mlp(x)
[docs] class PointNetSegHead(nn.Module): r""" Segmentation head from the [PointNet]_ paper. .. tip:: This head can be used for segmentation tasks. Parameters ---------- n_outputs : int, default=1 Examples -------- >>> head = PointNetSegHead(n_outputs=2) >>> x = torch.randn(32, 1088, 400) >>> head(x).shape torch.Size([32, 400, 2]) """ def __init__(self, n_outputs: int = 1) -> None: super().__init__() self.shared_mlp = nn.Sequential( conv1d_block(None, 512, kernel_size=1, bias=False), conv1d_block(512, 256, kernel_size=1, bias=False), conv1d_block(256, 128, kernel_size=1, bias=False), nn.Conv1d(128, n_outputs, kernel_size=1), )
[docs] def forward(self, x: Tensor) -> Tensor: r""" Run the forward pass. Parameters ---------- x : tensor of shape (B, C, N). Returns ------- out : tensor of shape (B, N, n_outputs) """ out = self.shared_mlp(x) # Shape (B, n_outputs, N). return out.transpose(2, 1) # Shape (B, N, n_outputs).
[docs] class PointNet(torch.nn.Module): r""" Vanilla version from the [PointNet]_ paper where :class:`TNet`'s have been removed. :class:`PointNet` takes as input a point cloud and produces one or more outputs. *The type of the task is determined by* ``head``. Currently implemented heads include: 1. :class:`.PointNetClsHead`: classification and regression 2. :class:`.PointNetSegHead`: segmentation The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is the batch size, ``C`` is the number of input channels and ``N`` is the number of points in each point cloud. .. tip:: You can define a ``custom_head`` head as a :class:`torch.nn.Module` and pass it to ``head``. If ``local_features=False``, the input to ``custom_head`` must have the same shape as in :meth:`PointNetClsHead.forward`. Otherwise, the input to ``custom_head`` must have the same shape as in :meth:`PointNetSegHead.forward`. Parameters ---------- head : torch.nn.Module n_global_feats : int, default=1024 local_feats : bool, default=False See Also -------- :class:`modules.PointNetBackbone` : For a description of ``local_feats`` and ``n_global_feats``. Examples -------- >>> cls_head = PointNetClsHead(n_outputs=2) >>> seg_head = PointNetSegHead(n_outputs=10) >>> x = torch.randn(32, 4, 300) >>> cls_net = PointNet(cls_head, 256) >>> cls_net(x).shape torch.Size([32, 2]) >>> cls_net.backbone(x).shape # Only features. torch.Size([32, 256]) >>> seg_net = PointNet(head=seg_head, n_global_feats=512, local_feats=True) >>> seg_net(x).shape torch.Size([32, 300, 10]) >>> seg_net.backbone(x, True)[1].shape # Features and critical indices. torch.Size([32, 512]) """ def __init__( self, head: Module, n_global_feats: int = 1024, local_feats: bool = False ) -> None: super().__init__() self.backbone = PointNetBackbone( n_global_feats=n_global_feats, local_feats=local_feats, ) self.head = head
[docs] def forward(self, x: Tensor) -> Any: r""" Run the forward pass. Parameters ---------- x : tensor of shape (B, C, N) Returns ------- out : tensor Output of ``head``. """ return self.head(self.backbone(x))