aidsorb.litmodels
This module provides LightningModules for use with PyTorch Lightning.
- class aidsorb.litmodels.PointLit(model, loss, metric, lr=0.001)
Bases: LightningModule

LightningModule for the models in aidsorb.models.

Parameters:
model (torch.nn.Module) – Currently, only PointNet is available.
loss (callable) – The loss function to be optimized during training. For valid options, see loss functions.
metric (torchmetrics.MetricCollection) – The performance metric(s) to be logged and optionally monitored.
Note
The metric is logged at the epoch level.

Tip
You can use 'val_<MetricName>' as the quantity to monitor. For example, if metric=MetricCollection(R2Score(), MeanAbsoluteError()) and you want to monitor R2Score, configure the ModelCheckpoint as follows:

    from lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(monitor='val_R2Score', mode='max', ...)
lr (float, default=0.001) – The learning rate for the Adam optimizer.
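The tip above names the monitored quantity as `val_` followed by the metric's name. A minimal, dependency-free sketch of that naming rule, assuming (as the `'val_R2Score'` example suggests) that the name is the metric's class name. `monitor_key` is a hypothetical helper, not part of aidsorb, and the two empty classes are stand-ins for the torchmetrics classes so the snippet runs on its own:

```python
# Stand-ins for torchmetrics classes (hypothetical; only their class names matter here).
class R2Score:
    pass


class MeanAbsoluteError:
    pass


# Importing under an alias (as in `MeanAbsoluteError as MAE`) does not change the class name.
MAE = MeanAbsoluteError


def monitor_key(metric, stage="val"):
    """Hypothetical helper: build the '<stage>_<MetricName>' key for a metric instance."""
    return f"{stage}_{type(metric).__name__}"


metrics = [R2Score(), MAE()]
keys = [monitor_key(m) for m in metrics]
print(keys)
```

Under this assumption, monitoring the mean absolute error from the tip's collection would use `monitor='val_MeanAbsoluteError'` with `mode='min'`, not `'val_MAE'`.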
Examples
>>> import torch
>>> from aidsorb.litmodels import PointLit
>>> from aidsorb.modules import PointNetClsHead
>>> from aidsorb.models import PointNet
>>> from torch.nn import MSELoss
>>> from torchmetrics import MetricCollection, R2Score, MeanAbsoluteError as MAE

>>> model = PointNet(head=PointNetClsHead(n_outputs=10))
>>> loss, metric = MSELoss(), MetricCollection(R2Score(), MAE())
>>> litmodel = PointLit(model=model, loss=loss, metric=metric)

>>> x = torch.randn(32, 5, 100)
>>> out = litmodel(x)
>>> out.shape
torch.Size([32, 10])
- test_step(batch, batch_idx)
Make predictions on a single batch from the test set for epoch-level operations.
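Epoch-level operations aggregate the predictions from every batch before the metric is evaluated. For a non-decomposable metric such as R², this is not the same as averaging per-batch scores. A dependency-free sketch of the difference, with hypothetical data and a hand-written `r2_score` standing in for `torchmetrics.R2Score` (it illustrates the idea, not aidsorb's internals):

```python
def r2_score(preds, targets):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, targets))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot


# Hypothetical (predictions, targets) for the two batches of one epoch.
batches = [
    ([1.1, 1.9, 3.2], [1.0, 2.0, 3.0]),
    ([9.5, 10.4, 11.0], [10.0, 10.5, 11.0]),
]

# Epoch-level: collect all predictions and targets first, then compute the metric once.
all_preds = [p for preds, _ in batches for p in preds]
all_targets = [t for _, targets in batches for t in targets]
epoch_r2 = r2_score(all_preds, all_targets)

# Naive alternative: the mean of per-batch scores.
batch_mean_r2 = sum(r2_score(p, t) for p, t in batches) / len(batches)

print(f"epoch-level R2:   {epoch_r2:.4f}")
print(f"mean of batch R2: {batch_mean_r2:.4f}")
```

Each batch's score depends on that batch's own target variance, so averaging them penalizes batches with a narrow target range. Pooling the whole epoch avoids this.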
- training_step(batch, batch_idx)
Compute and return the training loss on a single batch from the train set. Also, make predictions that will be used for epoch-level operations.
Note
Inference mode is enabled during predictions, so that the reported training performance is an accurate estimate (e.g. it is not distorted by Dropout).
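The note above matters because a layer such as Dropout behaves differently depending on the mode. In training mode it randomly zeroes activations and rescales the survivors, so predictions made in that mode are noisy. A dependency-free sketch with a hypothetical `toy_dropout` function (it follows the inverted-dropout scheme that `torch.nn.Dropout` uses, but is not PyTorch code) and a hypothetical one-layer `predict` model, showing why predictions used for metrics are made with dropout disabled:

```python
import random


def toy_dropout(values, p, training, rng):
    """Inverted dropout: in training mode, zero each value with prob. p and rescale the rest."""
    if not training or p == 0.0:
        return list(values)  # evaluation/inference: identity
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


def predict(features, training, rng):
    """Hypothetical model: dropout followed by a summing 'layer'."""
    return sum(toy_dropout(features, p=0.5, training=training, rng=rng))


rng = random.Random(0)
features = [1.0, 2.0, 3.0, 4.0]

train_mode_preds = [predict(features, training=True, rng=rng) for _ in range(5)]
eval_mode_preds = [predict(features, training=False, rng=rng) for _ in range(5)]

print("training mode:", train_mode_preds)  # varies from call to call
print("eval mode:    ", eval_mode_preds)   # always 10.0
```

Scoring the noisy training-mode outputs would mix dropout noise into the reported metric. The deterministic outputs reflect what the model actually predicts.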