Source code for aidsorb.modules

# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas

# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

r"""
This module provides :class:`torch.nn.Module`'s for building the architectures
in :class:`aidsorb.models`.

Currently, the module provides the basic blocks for building the architecture
from the [PointNet]_ paper.

.. note::
    :class:`PointNetBackbone`, :class:`PointNetClsHead` and
    :class:`PointNetSegHead` have their initial layers *lazy initialized*, so
    you don't need to specify the input dimensionality.

.. warning::
    It is recommended to **use batched inputs in all cases**. For example, even
    if a single ``pcd`` of shape ``(3+C, N)`` is to be passed to
    :class:`PointNetBackbone`, **reshape it to** ``(1, 3+C, N)``. One way to you
    can do it is the following: ``pcd = pcd.unsqueeze(0)``.
"""

import torch
from torch import nn


[docs] def conv1d_block(in_channels, out_channels, **kwargs): r""" Return a 1D convolutional block. The block has the following form:: block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), nn.ReLU(), ) Parameters ---------- in_channels : int or None If ``None``, the ``conv_layer`` is lazy initialized. out_channels : int **kwargs Valid keyword arguments for :class:`torch.nn.Conv1d`. Returns ------- block : :class:`torch.nn.Sequential` See Also -------- :class:`torch.nn.Conv1d` : For a description of the parameters. Examples -------- >>> x = torch.randn(32, 4, 100) # Shape (B, C_in, N). >>> block = conv1d_block(4, 128, kernel_size=1) >>> block(x).shape # Shape (B, C_out, N). torch.Size([32, 128, 100]) >>> # Lazy initialized. >>> block = conv1d_block(None, 16, kernel_size=1) >>> block(x).shape torch.Size([32, 16, 100]) """ if in_channels is not None: conv_layer = nn.Conv1d(in_channels, out_channels, **kwargs) else: conv_layer = nn.LazyConv1d(out_channels, **kwargs) block = nn.Sequential( conv_layer, nn.BatchNorm1d(out_channels), nn.ReLU(), ) return block
[docs] def dense_block(in_features, out_features, **kwargs): r""" Return a dense block. The block has the following form:: block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), nn.ReLU(), ) Parameters ---------- in_features : int or None If ``None``, the ``linear_layer`` is lazy initialized. out_features : int **kwargs Valid keyword arguments for :class:`torch.nn.Linear`. Returns ------- block : :class:`torch.nn.Sequential` See Also -------- :class:`torch.nn.Linear` : For a description of the parameters. Examples -------- >>> x = torch.randn(64, 3) # Shape (B, in_features). >>> block = dense_block(3, 10) >>> block(x).shape # Shape (B, out_features). torch.Size([64, 10]) >>> # Lazy initialized. >>> block = dense_block(None, 16) >>> block(x).shape torch.Size([64, 16]) """ if in_features is not None: linear_layer = nn.Linear(in_features, out_features, **kwargs) else: linear_layer = nn.LazyLinear(out_features, **kwargs) block = nn.Sequential( linear_layer, nn.BatchNorm1d(out_features), nn.ReLU(), ) return block
[docs] class TNet(nn.Module): r""" Spatial transformer network (STN) from the [PointNet]_ paper for performing the input and feature transform. ``T-Net`` takes as input a (possibly embedded) point cloud of shape ``(dim, N)`` and regresses a ``(dim, dim)`` matrix. Each point in the point cloud has shape ``(dim,)``. The input must be *batched*, i.e. have shape of ``(B, dim, N)``, where ``B`` is the batch size and ``N`` is the number of points in each point cloud. Parameters ---------- embed_dim : int The embedding dimension. Examples -------- >>> tnet = TNet(embed_dim=64) >>> x = torch.randn((128, 64, 42)) # Shape (B, embed_dim, N). >>> tnet(x).shape torch.Size([128, 64, 64]) """ def __init__(self, embed_dim): super().__init__() self.embed_dim = embed_dim self.conv_blocks = nn.Sequential( conv1d_block(embed_dim, 64, kernel_size=1, bias=False), conv1d_block(64, 128, kernel_size=1, bias=False), conv1d_block(128, 1024, kernel_size=1, bias=False), ) self.dense_blocks = nn.Sequential( dense_block(1024, 512, bias=False), dense_block(512, 256, bias=False), nn.Linear(256, embed_dim * embed_dim), )
[docs] def forward(self, x): r""" Return the regressed matrices. Parameters ---------- x : tensor of shape (B, embed_dim, N) Returns ------- out : tensor of shape (B, embed_dim, embed_dim) The regressed matrices. """ # Input has shape (B, self.embed_dim, N). bs = x.shape[0] x = self.conv_blocks(x) x, _ = torch.max(x, 2, keepdim=False) # Ignore indices. x = self.dense_blocks(x) # Initialize the identity matrix. identity = torch.eye(self.embed_dim, device=x.device, requires_grad=True).repeat(bs, 1, 1) # Output has shape (B, self.embed_dim, self.embed_dim). x = x.view(-1, self.embed_dim, self.embed_dim) + identity return x
[docs] class PointNetBackbone(nn.Module): r""" Backbone of the :class:`PointNet` model. This block is responsible for obtaining the *local and global features*, which can then be passed to a task head for predictions. This block also returns the *critical indices*. The input must be *batched*, i.e. have shape of ``(B, C, N)`` where ``B`` is the batch size, ``C`` is the number of input channels and ``N`` is the number of points in each point cloud. Parameters ---------- local_feats : bool, default=False If ``True``, the returned features are a concatenation of local features and global features. Otherwise, the global features are returned. n_global_feats : int, default=1024 The number of global features. Examples -------- >>> feat = PointNetBackbone(n_global_feats=2048) >>> x = torch.randn((32, 4, 200)) >>> features, indices = feat(x) >>> features.shape torch.Size([32, 2048]) >>> indices.shape torch.Size([32, 2048]) >>> feat = PointNetBackbone(local_feats=True, n_global_feats=1024) >>> x = torch.randn((16, 4, 100)) >>> features, indices = feat(x) >>> features.shape torch.Size([16, 1088, 100]) >>> indices.shape torch.Size([16, 1024]) """ def __init__(self, local_feats=False, n_global_feats=1024): super().__init__() self.local_feats = local_feats # First shared MLP. self.shared_mlp_1 = nn.Sequential( conv1d_block(None, 64, kernel_size=1, bias=False), conv1d_block(64, 64, kernel_size=1, bias=False), ) # Second shared MLP. self.shared_mlp_2 = nn.Sequential( conv1d_block(64, 64, kernel_size=1, bias=False), conv1d_block(64, 128, kernel_size=1, bias=False), conv1d_block(128, n_global_feats, kernel_size=1, bias=False), )
[docs] def forward(self, x): r""" Return the *features* and *critical indices*. The type of the features is determined by ``local_feats``. Parameters ---------- x : tensor of shape (B, C, N) Returns ------- out : tuple of length 2 * ``out[0] == features`` * ``out[1] == critical_indices`` """ n_points = x.shape[2] x = self.shared_mlp_1(x) if self.local_feats: point_feats = x.clone() x = self.shared_mlp_2(x) # Shape (B, n_global_feats). global_feats, critical_indices = torch.max(x, 2, keepdim=False) if self.local_feats: # Shape (B, n_global_feats + 64, N) feats = torch.cat( (point_feats, global_feats.unsqueeze(-1).repeat(1, 1, n_points)), dim=1 ) return feats, critical_indices return global_feats, critical_indices
[docs] class PointNetClsHead(nn.Module): r""" Classification head from the [PointNet]_ paper. .. note:: This head can be used for classification or regression. Parameters ---------- n_outputs : int, default=1 dropout_rate : float, default=0 Examples -------- >>> head = PointNetClsHead(n_outputs=4) >>> x = torch.randn(64, 13) >>> head(x).shape torch.Size([64, 4]) """ def __init__(self, n_outputs=1, dropout_rate=0): super().__init__() self.mlp = nn.Sequential( dense_block(None, 512, bias=False), dense_block(512, 256, bias=False), nn.Dropout(dropout_rate), nn.Linear(256, n_outputs), )
[docs] def forward(self, x): r""" Run the forward pass. Parameters ---------- x : tensor of shape (B, C) Returns ------- out : tensor of shape (B, n_outputs) """ x = self.mlp(x) return x
[docs] class PointNetSegHead(nn.Module): r""" Segmentation head from the [PointNet]_ paper. .. note:: This head can be used for segmentation. Parameters ---------- n_outputs : int, default=1 Examples -------- >>> head = PointNetSegHead(n_outputs=2) >>> x = torch.randn(32, 1088, 400) >>> head(x).shape torch.Size([32, 400, 2]) """ def __init__(self, n_outputs=1): super().__init__() self.shared_mlp = nn.Sequential( conv1d_block(None, 512, kernel_size=1, bias=False), conv1d_block(512, 256, kernel_size=1, bias=False), conv1d_block(256, 128, kernel_size=1, bias=False), nn.Conv1d(128, n_outputs, kernel_size=1), )
[docs] def forward(self, x): r""" Run the forward pass. Parameters ---------- x : tensor of shape (B, C, N). Returns ------- out : tensor of shape (B, N, n_outputs) """ x = self.shared_mlp(x) # Shape (B, n_outputs, N). x = x.transpose(2, 1) # Shape (B, N, n_outputs). return x