Source code for aidsorb.transforms.voxels

# This file is part of AIdsorb.
# Copyright (C) 2026 Antonios P. Sarikas

# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

r"""
Helper functions and classes for transforming voxels.

.. note::
    All geometric transforms expect an input :class:`~torch.Tensor` of shape ``(C, D, H, W)``.
"""

from itertools import combinations

import torch


def _check_shape(obj):
    r"""
    Validate that ``obj`` is four-dimensional, i.e. shaped like voxels.

    Only the number of dimensions is inspected, not their sizes.

    Parameters
    ----------
    obj : array or tensor
        Object exposing ``ndim`` and ``shape`` attributes.

    Raises
    ------
    ValueError
        If ``obj.ndim != 4``.

    Examples
    --------
    >>> x = torch.randn(2, 2, 2, 2)
    >>> _check_shape(x)
    >>> x = torch.randn(2, 2)
    >>> _check_shape(x)
    Traceback (most recent call last):
    ...
    ValueError: expecting shape (C, D, H, W) but received shape (2, 2)
    """
    # Guard clause: a valid input falls straight through and returns ``None``.
    if obj.ndim == 4:
        return

    received = tuple(obj.shape)
    raise ValueError(
            'expecting shape (C, D, H, W) '
            f'but received shape {received}'
            )

class AddChannelDim:
    r"""
    Insert a new leading dimension of size one into the input tensor.

    Examples
    --------
    >>> x = torch.randn(32, 32, 32)
    >>> AddChannelDim()(x).shape
    torch.Size([1, 32, 32, 32])
    """
    def __call__(self, x):
        # The new axis goes at position 0, ahead of all existing axes.
        with_channel = x.unsqueeze(dim=0)
        return with_channel
class BoltzmannFactor:
    r"""
    Replace every voxel value ``v`` with ``exp(-v / temperature)``.

    Parameters
    ----------
    temperature : float, default=298.

    Examples
    --------
    >>> x = torch.tensor([0., torch.inf])
    >>> BoltzmannFactor()(x)
    tensor([1., 0.])
    """
    def __init__(self, temperature: float = 298.):
        self.temperature = temperature

    def __call__(self, x):
        # NOTE(review): no Boltzmann constant appears in the exponent, so voxel
        # values are presumably already expressed in temperature units - confirm.
        neg_inverse_temperature = -1 / self.temperature
        return torch.exp(neg_inverse_temperature * x)
class ClipVoxels:
    r"""
    Limit voxel values to the closed interval ``[vmin, vmax]``.

    Parameters
    ----------
    vmin : float
        Lower bound; smaller values are replaced by ``vmin``.
    vmax : float
        Upper bound; larger values are replaced by ``vmax``.

    Examples
    --------
    >>> x = torch.tensor([-20., 22.])
    >>> out = ClipVoxels(-1, 1)(x)
    >>> out
    tensor([-1.,  1.])
    """
    def __init__(self, vmin: float, vmax: float):
        self.vmin, self.vmax = vmin, vmax

    def __call__(self, x):
        # ``torch.clip`` is an alias of ``torch.clamp``.
        return torch.clamp(x, min=self.vmin, max=self.vmax)
class ClipScaleVoxels:
    r"""
    Clip voxels symmetrically and rescale them into ``[-1, 1]``.

    Voxels are first limited to ``[-value, value]`` and the result is then
    divided by ``value``.

    Parameters
    ----------
    value : float, default=5e3
        Clipping bound. Its sign is ignored, only the magnitude is used.

    Examples
    --------
    >>> x = torch.tensor([-12., 11.])
    >>> ClipScaleVoxels(10)(x)
    tensor([-1.,  1.])
    """
    def __init__(self, value: float = 5e3):
        # Store the magnitude so that a negative ``value`` behaves like a positive one.
        self.value = abs(value)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        bound = self.value
        # Clip to [-bound, bound], then divide to land in [-1, 1].
        return torch.clamp(x, -bound, bound) / bound
class RandomNoise:
    r"""
    Add zero-mean normal noise to voxels.

    The noise is sampled on the same device as the input. For floating point
    inputs the output keeps the dtype of the input.

    Parameters
    ----------
    std : float
        Standard deviation of the normal noise.

    Examples
    --------
    >>> x = torch.randn(3, 3)
    >>> out = RandomNoise(0.1)(x)
    >>> out.shape
    torch.Size([3, 3])
    >>> torch.equal(x, out)
    False
    >>> RandomNoise(0.1)(x.half()).dtype
    torch.float16
    """
    def __init__(self, std):
        self.std = std

    def __call__(self, x):
        # Sample in the default dtype so that seeded results for float32 and
        # float64 inputs stay exactly as before.
        noise = torch.randn(x.shape, device=x.device) * self.std

        # Without this cast, adding default-dtype noise to a half/bfloat16
        # input silently promotes the output to a wider dtype.
        if x.is_floating_point():
            noise = noise.to(x.dtype)

        return x + noise
class RandomRotation90:
    r"""
    Rotate voxels by 90 degrees in a randomly chosen plane and direction.

    The plane is drawn from the pairs of dimensions ``(1, 2)``, ``(1, 3)``
    and ``(2, 3)``, so dimension 0 is never rotated.

    Examples
    --------
    >>> x = torch.randn(2, 3, 3, 3)
    >>> out = RandomRotation90()(x)
    >>> out.shape
    torch.Size([2, 3, 3, 3])
    >>> torch.equal(x, out)
    False
    """
    def __init__(self):
        # Candidate rotation planes: every pair of the last three dimensions.
        self.planes = list(combinations([1, 2, 3], 2))
        # Candidate numbers of quarter turns; the sign selects the direction.
        self.directions = torch.tensor([-1, 1])

    def __call__(self, x):
        _check_shape(x)

        # Draw the plane first and the direction second, so that the random
        # stream is consumed in a fixed order.
        plane_idx = torch.randint(len(self.planes), ()).item()
        direction_idx = torch.randint(len(self.directions), ()).item()

        dims = self.planes[plane_idx]
        quarter_turns = int(self.directions[direction_idx])

        return torch.rot90(x, k=quarter_turns, dims=dims)
class RandomFlip:
    r"""
    Reverse voxels along one randomly chosen dimension.

    The dimension is drawn from ``1``, ``2`` and ``3``, so dimension 0 is
    never flipped.

    Examples
    --------
    >>> x = torch.randn(2, 3, 3, 3)
    >>> out = RandomFlip()(x)
    >>> out.shape
    torch.Size([2, 3, 3, 3])
    >>> torch.equal(x, out)
    False
    """
    def __call__(self, x):
        _check_shape(x)

        # The upper bound of ``randint`` is exclusive: the draw is 1, 2 or 3.
        axis = torch.randint(low=1, high=4, size=()).item()

        return torch.flip(x, dims=[axis])
class RandomReflect:
    r"""
    Reflect voxels by swapping a randomly chosen pair of dimensions.

    The pair is drawn from ``(1, 2)``, ``(1, 3)`` and ``(2, 3)``, so
    dimension 0 is never moved.

    Examples
    --------
    >>> x = torch.randn(2, 3, 3, 3)
    >>> out = RandomReflect()(x)
    >>> out.shape
    torch.Size([2, 3, 3, 3])
    >>> torch.equal(x, out)
    False
    """
    def __init__(self):
        # Candidate pairs of dimensions to swap.
        self.planes = list(combinations([1, 2, 3], 2))

    def __call__(self, x):
        _check_shape(x)

        chosen = torch.randint(len(self.planes), ()).item()
        dim0, dim1 = self.planes[chosen]

        return torch.transpose(x, dim0, dim1)