aidsorb.litmodels

This module provides LightningModules for use with PyTorch Lightning.

class aidsorb.litmodels.PointLit(model, loss, metric, lr=0.001)

Bases: LightningModule

LightningModule wrapper for the models in aidsorb.models.

Parameters:
  • model (torch.nn.Module) – The model to be trained. Currently, only PointNet is available.

  • loss (callable) – The loss function to be optimized during training. For valid options, see loss functions.

  • metric (torchmetrics.MetricCollection) –

    The performance metric(s) to be logged and optionally monitored.

    Note

    The metric is logged at the epoch level.

    Tip

    You can use 'val_<MetricName>' as the quantity to monitor. For example, if metric=MetricCollection(R2Score(), MeanAbsoluteError()) and you want to monitor R2Score, configure the ModelCheckpoint as follows:

    from lightning.pytorch.callbacks import ModelCheckpoint
    
    checkpoint_callback = ModelCheckpoint(monitor='val_R2Score', mode='max', ...)
    

  • lr (float, default=0.001) – The learning rate for the Adam optimizer.
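The 'val_<MetricName>' key in the tip appears to be built from the metric's class name, not from the name it was imported under. The sketch below illustrates that convention with hypothetical stand-in classes rather than the real torchmetrics ones. It assumes the key is the stage prefix plus `type(metric).__name__`, which is how torchmetrics' `MetricCollection` names its entries by default when given a list of metrics; `logged_keys` is a hypothetical helper, not part of aidsorb.

```python
# Hypothetical stand-ins for torchmetrics classes; only their names matter here.
class R2Score:
    pass


class MeanAbsoluteError:
    pass


MAE = MeanAbsoluteError  # an import alias does not change the class name


def logged_keys(metrics, stage):
    """Return the keys under which each metric would be logged for ``stage``.

    Assumption: key = '<stage>_' + class name of the metric.
    """
    return [f"{stage}_{type(m).__name__}" for m in metrics]


keys = logged_keys([R2Score(), MAE()], stage="val")
print(keys)  # ['val_R2Score', 'val_MeanAbsoluteError']
```

Under this assumption, monitoring the MAE from the tip's collection would use monitor='val_MeanAbsoluteError' (with mode='min'), even though the class was imported as MAE.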

Examples

>>> import torch
>>> from aidsorb.litmodels import PointLit
>>> from aidsorb.modules import PointNetClsHead
>>> from aidsorb.models import PointNet
>>> from torch.nn import MSELoss
>>> from torchmetrics import MetricCollection, R2Score, MeanAbsoluteError as MAE
>>> model = PointNet(head=PointNetClsHead(n_outputs=10))
>>> loss, metric = MSELoss(), MetricCollection(R2Score(), MAE())
>>> litmodel = PointLit(model=model, loss=loss, metric=metric)
>>> x = torch.randn(32, 5, 100)
>>> out = litmodel(x)
>>> out.shape
torch.Size([32, 10])

configure_optimizers()

Return the Adam optimizer, configured with the learning rate lr.
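A minimal sketch of what configure_optimizers() could look like, given the lr parameter described above. It assumes the optimizer is `torch.optim.Adam` over all parameters of the wrapped model; `OptimizerSketch` is a hypothetical stand-in for PointLit, and a plain `torch.nn.Module` replaces LightningModule so the snippet runs without Lightning installed.

```python
import torch
from torch import nn


class OptimizerSketch(nn.Module):
    """Hypothetical stand-in for PointLit, showing only configure_optimizers."""

    def __init__(self, model: nn.Module, lr: float = 0.001):
        super().__init__()
        self.model = model  # registered as a submodule, so its parameters are tracked
        self.lr = lr

    def configure_optimizers(self):
        # Adam over all trainable parameters, using the stored learning rate.
        return torch.optim.Adam(self.parameters(), lr=self.lr)


lit = OptimizerSketch(nn.Linear(4, 2), lr=0.01)
opt = lit.configure_optimizers()
print(type(opt).__name__, opt.param_groups[0]["lr"])  # Adam 0.01
```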

forward(x)

Run the forward pass (the forward method) of model.

predict_step(batch, batch_idx)

Return predictions on a single batch.
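forward delegates to the wrapped model, and predict_step applies it to one batch. The sketch below assumes each batch is an (inputs, targets) pair and that prediction ignores the targets; both the batch layout and the class name `PredictSketch` are assumptions for illustration, not taken from aidsorb.

```python
import torch
from torch import nn


class PredictSketch(nn.Module):
    """Hypothetical stand-in for PointLit, showing forward and predict_step."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Delegate the forward pass to the wrapped model.
        return self.model(x)

    def predict_step(self, batch, batch_idx):
        # Assumption: batch is (inputs, targets); targets are not needed here.
        x, _ = batch
        return self(x)


lit = PredictSketch(nn.Linear(3, 2))
batch = (torch.randn(8, 3), torch.randn(8, 2))
preds = lit.predict_step(batch, batch_idx=0)
print(preds.shape)  # torch.Size([8, 2])
```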

test_step(batch, batch_idx)

Make predictions on a single batch from the test set, to be used in epoch-level operations.
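One reason to collect predictions per batch and compute the metric once per epoch is that averaging per-batch metric values can give a different number when batches differ in size. The pure-Python sketch below illustrates this with mean absolute error; the data and the `mae` helper are made up for illustration and do not come from aidsorb.

```python
def mae(preds, targets):
    """Mean absolute error over paired sequences."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


# Two batches of unequal size: (predictions, targets).
batches = [
    ([1.0, 2.0, 3.0], [2.0, 3.0, 4.0]),  # absolute errors 1, 1, 1
    ([0.0], [5.0]),                      # absolute error 5
]

# Batch-level: average the per-batch MAE values.
batch_avg = sum(mae(p, t) for p, t in batches) / len(batches)

# Epoch-level: accumulate every prediction, then compute the metric once.
all_preds = [v for p, _ in batches for v in p]
all_targets = [v for _, t in batches for v in t]
epoch_mae = mae(all_preds, all_targets)

print(batch_avg, epoch_mae)  # 3.0 2.0
```

The epoch-level value weights every sample equally, whereas the batch average over-weights the single-sample batch.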

training_step(batch, batch_idx)

Compute and return the training loss on a single batch from the train set.

Also, make predictions to be used in epoch-level operations.

Note

Inference mode is enabled while making these predictions, so the reported training performance is an accurate estimate (e.g. it is unaffected by Dropout).
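The note can be illustrated with a small model containing Dropout. This sketch assumes that "inference mode" covers both switching the model to evaluation mode (`model.eval()`, which is what disables Dropout) and running under `torch.inference_mode()` (which only disables gradient tracking). The function `training_step_sketch` and the (inputs, targets) batch layout are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
loss_fn = nn.MSELoss()


def training_step_sketch(model, batch):
    """Hypothetical sketch: return the training loss and dropout-free predictions."""
    x, y = batch  # assumption: batch is (inputs, targets)

    # Loss for the optimizer, computed in training mode (dropout active).
    model.train()
    loss = loss_fn(model(x), y)

    # Predictions for epoch-level metrics: eval mode disables dropout,
    # inference mode disables gradient tracking.
    model.eval()
    with torch.inference_mode():
        preds = model(x)
    model.train()  # restore training mode for the next step

    return loss, preds


x, y = torch.randn(8, 4), torch.randn(8, 4)
loss, preds_a = training_step_sketch(model, (x, y))
_, preds_b = training_step_sketch(model, (x, y))
print(torch.equal(preds_a, preds_b))  # True: no dropout noise in the predictions
```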

validation_step(batch, batch_idx)

Make predictions on a single batch from the validation set, to be used in epoch-level operations.
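Epoch-level metrics commonly follow an update/compute/reset cycle: each validation batch updates the metric state, the value is computed and logged once at the end of the epoch, and the state is then cleared so that epochs do not mix. The sketch below mimics that cycle with a hypothetical `RunningMAE` class standing in for a torchmetrics metric, and a plain dict standing in for Lightning's logger; the 'val_' key follows the monitoring tip for PointLit.

```python
class RunningMAE:
    """Hypothetical stateful metric with an update/compute/reset cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.abs_error_sum = 0.0
        self.count = 0

    def update(self, preds, targets):
        self.abs_error_sum += sum(abs(p - t) for p, t in zip(preds, targets))
        self.count += len(preds)

    def compute(self):
        return self.abs_error_sum / self.count


metric = RunningMAE()
logged = {}  # stand-in for the logger: key -> list of per-epoch values

epochs = [
    [([1.0, 2.0], [1.5, 2.5]), ([3.0], [5.0])],  # epoch 0: errors 0.5, 0.5, 2.0
    [([1.0, 2.0], [1.0, 2.0]), ([3.0], [3.0])],  # epoch 1: perfect predictions
]
for epoch_batches in epochs:
    for preds, targets in epoch_batches:  # per validation batch: update only
        metric.update(preds, targets)
    # End of the validation epoch: compute, log once, then reset.
    logged.setdefault("val_MeanAbsoluteError", []).append(metric.compute())
    metric.reset()

print(logged)  # {'val_MeanAbsoluteError': [1.0, 0.0]}
```

Without the reset, the second epoch's value would still include the first epoch's errors.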