Coming back after model training
After training a model, you might want to measure its performance, make predictions, or use it in any other way.
Note
This example assumes that:
- PyTorch Lightning checkpoints were enabled during training.
- Training was performed with the AIdsorb CLI or with AIdsorb + PyTorch Lightning.
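Everything below starts from a checkpoint file, so it can help to locate the most recent one programmatically. The following is a minimal, stdlib-only sketch of a hypothetical helper, latest_checkpoint, which assumes the checkpoints keep the default file names of Lightning's ModelCheckpoint (epoch=<E>-step=<S>.ckpt). It ranks files by training progress, not by validation score, so it does not replace a "best model" checkpoint.

```python
import re
from pathlib import Path

# Lightning's ModelCheckpoint names files "epoch=<E>-step=<S>.ckpt" by default.
_CKPT_NAME = re.compile(r'epoch=(\d+)-step=(\d+)\.ckpt')


def latest_checkpoint(paths):
    """Return the path with the highest (epoch, step) pair.

    Files that do not follow the default naming scheme (e.g. last.ckpt or
    custom filename patterns) are ignored.
    """
    candidates = []
    for path in map(Path, paths):
        match = _CKPT_NAME.fullmatch(path.name)
        if match:
            epoch, step = (int(g) for g in match.groups())
            candidates.append(((epoch, step), path))
    if not candidates:
        raise ValueError('no checkpoint named epoch=<E>-step=<S>.ckpt was found')
    # Compare numerically, so epoch=10 ranks above epoch=9.
    return max(candidates, key=lambda item: item[0])[1]


# Usage (the directory below is a placeholder):
# ckpt_path = latest_checkpoint(Path('path/to/checkpoints').glob('*.ckpt'))
```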
import yaml
import torch
import lightning as L
from lightning.pytorch.cli import LightningArgumentParser
from torch.utils.data import DataLoader
from aidsorb.datamodules import PCDDataModule
from aidsorb.litmodels import PointLit
The following snippet lets us instantiate the:
- Trainer
- LightningModule (litmodel)
- Datamodule
with the same settings as in the .yaml configuration file. For more information 👉 here.
Note
You are responsible for restoring the model’s state (the weights of the model).
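# Path to the .yaml configuration file used for training (placeholder, adjust it).
filename = 'path/to/config.yaml'
with open(filename, 'r') as f:
    config_dict = yaml.safe_load(f)
# The logger, the seed and the checkpoint path are not needed during inference.
config_dict['trainer']['logger'] = False
del config_dict['seed_everything'], config_dict['ckpt_path']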
parser = LightningArgumentParser()
parser.add_lightning_class_args(PointLit, 'model')
parser.add_lightning_class_args(PCDDataModule, 'data')
parser.add_class_arguments(L.Trainer, 'trainer')
# Any other key present in the config file must also be added.
# parser.add_argument('--<keyname>', ...)
# For more information 👉 https://jsonargparse.readthedocs.io/en/stable/#parsers
config = parser.parse_object(config_dict)
objects = parser.instantiate_classes(config)
trainer, litmodel, dm = objects.trainer, objects.model, objects.data
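The del statement above raises a KeyError if the configuration file has no seed_everything or ckpt_path entry, and the edits modify config_dict in place. A slightly more tolerant sketch uses dict.pop with a default on a shallow copy; prepare_for_inference is a hypothetical helper name, not part of AIdsorb or Lightning.

```python
def prepare_for_inference(config_dict):
    """Return a copy of a LightningCLI-style config dict ready for inference.

    Unlike del, dict.pop with a default does not raise KeyError when a key
    (e.g. ckpt_path) is absent from the configuration file.
    """
    config = dict(config_dict)  # Shallow copy: the caller's dict is left untouched.
    config['trainer'] = dict(config.get('trainer') or {})
    config['trainer']['logger'] = False  # No logging during inference.
    for key in ('seed_everything', 'ckpt_path'):
        config.pop(key, None)
    return config


# Usage:
# config = parser.parse_object(prepare_for_inference(config_dict))
```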
Restoring model’s state
# Load the checkpoint.
ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt')
# Load back the weights.
litmodel.load_state_dict(ckpt['state_dict'])
# Set the model for inference (disable grads & enable eval mode).
litmodel.freeze()
print(f'Model in evaluation mode: {not litmodel.training}')
# Your code goes here.
...
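Lightning saves the weights of the whole LightningModule under ckpt['state_dict'], so every key belonging to a submodule is prefixed with the name of the attribute that holds it. If you need the weights of the bare PyTorch network (for example, to use it without Lightning), a minimal sketch is to keep only the keys with that prefix and strip it. The default prefix 'model.' is an assumption about how PointLit stores the network; inspect the keys of ckpt['state_dict'] to confirm it.

```python
def strip_prefix(state_dict, prefix='model.'):
    """Keep only the entries whose key starts with prefix and drop the prefix.

    The default prefix assumes the network lives in an attribute named
    `model` of the LightningModule (hypothetical, check your checkpoint).
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Usage (network: a bare torch.nn.Module with the same architecture, hypothetical name):
# network.load_state_dict(strip_prefix(ckpt['state_dict']))
```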
Measure performance
# Measure performance on the test set.
trainer.test(litmodel, dm)
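trainer.test returns a list with one dictionary per test dataloader, mapping each metric logged during testing to its value. A small sketch for printing that list in a readable form follows; metrics_table is a hypothetical helper, and the metric names depend entirely on what the LightningModule logs.

```python
def metrics_table(results, ndigits=4):
    """Format the list returned by trainer.test as one 'name: value' line per metric.

    results holds one dict per test dataloader; float() accepts both plain
    numbers and 0-d tensors.
    """
    lines = []
    for idx, metrics in enumerate(results):
        for name, value in sorted(metrics.items()):
            lines.append(f'[dataloader {idx}] {name}: {round(float(value), ndigits)}')
    return '\n'.join(lines)


# Usage:
# print(metrics_table(trainer.test(litmodel, dm)))
```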
Make predictions
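# Set up the datamodule.
dm.setup()
# Predict on the test set.
y_pred_test = torch.cat(trainer.predict(litmodel, dm.test_dataloader()))
# Predict on the train set (if the train dataloader shuffles, the order of the
# predictions will not match the order of the samples in the dataset).
y_pred_train = torch.cat(trainer.predict(litmodel, dm.train_dataloader()))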
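Once computed, the predictions are ordinary tensors that can be saved like any other data. A minimal sketch for writing them to a CSV file with the standard csv module is shown below; write_predictions is a hypothetical helper, and identifying samples by their index assumes the dataloader did not shuffle (pass names explicitly otherwise). The demo values are dummy numbers, not real predictions.

```python
import csv
import io


def write_predictions(f, predictions, names=None):
    """Write one CSV row per sample: an identifier followed by the predicted values.

    predictions is a list with one entry per sample, e.g. preds.tolist();
    scalar entries (1D predictions) are wrapped into one-element rows.
    names defaults to the sample index.
    """
    if names is None:
        names = range(len(predictions))
    writer = csv.writer(f)
    for name, row in zip(names, predictions):
        if not isinstance(row, (list, tuple)):
            row = [row]
        writer.writerow([name, *row])


# Demo with dummy values, written to an in-memory buffer.
buffer = io.StringIO()
write_predictions(buffer, [[0.5, 1.0], [2.0, 3.5]])
print(buffer.getvalue())

# With a tensor of predictions computed above (preds) and an output file:
# with open('predictions.csv', 'w', newline='') as f:
#     write_predictions(f, preds.tolist())
```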