.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/resume.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_resume.py:


Coming back after model training
================================

.. GENERATED FROM PYTHON SOURCE LINES 7-17

After training a model, you might want to test its performance, make
predictions, or use it in any other way.

.. note::
    This example assumes that:

    * `PyTorch Lightning checkpoints `_ were enabled during training.
    * Training was performed with the AIdsorb :doc:`../cli` or with :ref:`AIdsorb + PyTorch Lightning `.

.. GENERATED FROM PYTHON SOURCE LINES 17-26

.. code-block:: Python


    import yaml
    import torch
    import lightning as L
    from lightning.pytorch.cli import LightningArgumentParser
    from torch.utils.data import DataLoader
    from aidsorb.datamodules import PCDDataModule
    from aidsorb.litmodels import PointLit

.. GENERATED FROM PYTHON SOURCE LINES 27-36

The following snippet instantiates these objects with the same settings as in
the ``.yaml`` configuration file:

* the Trainer,
* the LightningModule (``litmodel``),
* the datamodule.

For more information 👉 `here `_.

.. GENERATED FROM PYTHON SOURCE LINES 38-40

.. note::
    You are responsible for restoring the model's state (the weights of the
    model). Instantiating the model from the configuration file only builds a
    freshly initialized model.

.. GENERATED FROM PYTHON SOURCE LINES 40-60

.. code-block:: Python


    # Path to the configuration file saved during training (placeholder).
    filename = 'path/to/config.yaml'

    with open(filename, 'r') as f:
        config_dict = yaml.safe_load(f)

    # The logger, the seed and the checkpoint path are not needed during inference.
    config_dict['trainer']['logger'] = False
    del config_dict['seed_everything'], config_dict['ckpt_path']

    parser = LightningArgumentParser()
    parser.add_lightning_class_args(PointLit, 'model')
    parser.add_lightning_class_args(PCDDataModule, 'data')
    parser.add_class_arguments(L.Trainer, 'trainer')

    # Any other key present in the config file must also be added.
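    # Optional check (an addition, not part of the original recipe): the
    # parser above only registers the 'model', 'data' and 'trainer' groups,
    # so list any remaining top-level keys of the config that still need to
    # be registered. Unregistered keys are likely to make parse_object fail.
    extra_keys = set(config_dict) - {'model', 'data', 'trainer'}
    print(f'Config keys still to register: {sorted(extra_keys)}')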
    # parser.add_argument(--, ...)
    # For more information 👉 https://jsonargparse.readthedocs.io/en/stable/#parsers

    config = parser.parse_object(config_dict)
    objects = parser.instantiate_classes(config)

.. GENERATED FROM PYTHON SOURCE LINES 61-64

.. code-block:: Python


    trainer, litmodel, dm = objects.trainer, objects.model, objects.data

.. GENERATED FROM PYTHON SOURCE LINES 67-69

Restoring the model's state
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 69-73

.. code-block:: Python


    # Load the checkpoint. map_location='cpu' lets a checkpoint saved on a GPU
    # also be loaded on a CPU-only machine.
    ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt', map_location='cpu')

.. GENERATED FROM PYTHON SOURCE LINES 74-78

.. code-block:: Python


    # Restore the trained weights.
    litmodel.load_state_dict(ckpt['state_dict'])

.. GENERATED FROM PYTHON SOURCE LINES 79-88

.. code-block:: Python


    # Prepare the model for inference (disable gradients and switch to eval mode).
    litmodel.freeze()
    print(f'Model in evaluation mode: {not litmodel.training}')

    # Your code goes here.
    ...

.. GENERATED FROM PYTHON SOURCE LINES 89-91

Measure performance
-------------------

.. GENERATED FROM PYTHON SOURCE LINES 91-95

.. code-block:: Python


    # Measure performance on the test set.
    trainer.test(litmodel, datamodule=dm)

.. GENERATED FROM PYTHON SOURCE LINES 96-98

Make predictions
----------------

.. GENERATED FROM PYTHON SOURCE LINES 98-107

.. code-block:: Python


    # Set up the datamodule.
    dm.setup()

    # Predict on the test set.
    y_pred_test = torch.cat(trainer.predict(litmodel, dm.test_dataloader()))

    # Predict on the train set. If the train dataloader shuffles, the order of
    # the predictions will not match the order of the training samples.
    y_pred_train = torch.cat(trainer.predict(litmodel, dm.train_dataloader()))


.. _sphx_glr_download_auto_examples_resume.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: resume.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: resume.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_