# Source code for aidsorb.visualize

# This file is part of AIdsorb.
# Copyright (C) 2024 Antonios P. Sarikas

# AIdsorb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

r"""
Helper functions for visualizing point clouds.

.. tip::

    To visualize a point cloud from the CLI:

        .. code-block:: console
            
            $ aidsorb visualize path/to/structure_or_pcd  # Structure file (.xyz, .cif, etc.) or .npy file

    You can also visualize a structure with :mod:`ase`:

        .. code-block:: python

            from ase.io import read
            from ase.visualize import view

            atoms = read('path/to/file')
            view(atoms)
"""

import os

import numpy as np
from numpy.typing import NDArray, ArrayLike
from plotly.graph_objects import Figure, Scatter3d

from ._internal import check_shape, ptable
from .utils import pcd_from_file


def get_atom_colors(atomic_numbers: ArrayLike, scheme: str = 'cpk') -> NDArray:
    r"""
    Convert atomic numbers to colors based on ``scheme``.

    Parameters
    ----------
    atomic_numbers : array-like of shape (N,)
    scheme : {'jmol', 'cpk'}, default='cpk'
        Color scheme. Colors are looked up in the ``'<scheme>_color'``
        column of the periodic table.

    Returns
    -------
    colors : array of shape (N,)
    """
    atomic_numbers = np.array(atomic_numbers)
    # Build the column name without mutating the caller-visible parameter.
    color_column = f'{scheme}_color'

    return ptable.loc[atomic_numbers, color_column].to_numpy()
def get_atom_names(atomic_numbers: ArrayLike) -> NDArray:
    r"""
    Map atomic numbers to the corresponding element names.

    Parameters
    ----------
    atomic_numbers : array-like of shape (N,)

    Returns
    -------
    elements : array of shape (N,)

    Examples
    --------
    >>> atomic_numbers = np.array([1, 2, 7])
    >>> get_atom_names(atomic_numbers)
    array(['Hydrogen', 'Helium', 'Nitrogen'], dtype=object)
    """
    lookup_keys = np.array(atomic_numbers)
    element_names = ptable.loc[lookup_keys, 'name']

    return element_names.to_numpy()
def draw_pcd(
        pcd: NDArray,
        molecular: bool = True,
        scheme: str = 'cpk',
        size: float = 2.,
        feature_to_color: tuple[int, str] | None = None,
        colorscale: str | None = None,
        ) -> Figure:
    r"""
    Visualize point cloud with Plotly.

    .. _colorscale: https://plotly.com/python/builtin-colorscales/

    Parameters
    ----------
    pcd : array of shape (N, 3+C)
        Columns 0, 1 and 2 are used as the x, y and z coordinates. If
        ``molecular=True``, column 3 must hold the atomic numbers.
    molecular : bool, default=True
        If :obj:`True`, assume a molecular point cloud. In this case, each
        atom is sized proportionally to :math:`r_{\text{vdW}}^4` and
        colorized based on ``scheme``. If :obj:`False`, assume a generic
        point cloud and size each point based on ``size``.
    scheme : {'jmol', 'cpk'}, default='cpk'
        Takes effect only if ``molecular=True`` and ``feature_to_color=None``.
    size : float, default=2.
        Controls the size of points. Takes effect only if
        ``molecular=False``.
    feature_to_color : tuple, optional
        Tuple of the form ``(index, label)``, where ``index`` is the index of
        the feature to be colored and ``label`` is the text label for the
        colorbar. If given, it overrides the per-element colors.
    colorscale : str, optional
        No effect if ``feature_to_color=None``. For available options, see
        `colorscale`_.

    Returns
    -------
    plotly.graph_objects.Figure

    Examples
    --------
    >>> pcd = np.random.randn(10, 3)
    >>> fig = draw_pcd(pcd, molecular=False, feature_to_color=(0, 'x coord'), colorscale='viridis')
    """
    check_shape(pcd)

    marker = {'size': size, 'color': None}
    hovertext = None

    if molecular:
        # Column 3 of a molecular point cloud holds the atomic numbers.
        atomic_numbers = pcd[:, 3]
        vdw_radii = ptable.loc[atomic_numbers, 'vdw_radius'].to_numpy()
        # The per-atom size replaces the generic ``size`` argument.
        marker['size'] = (vdw_radii * 0.01)**4
        marker['color'] = get_atom_colors(atomic_numbers, scheme=scheme)
        hovertext = get_atom_names(atomic_numbers)

    if feature_to_color is not None:
        # Coloring by a feature takes precedence over element colors.
        idx, label = feature_to_color
        marker.update({
            'color': pcd[:, idx],
            'colorscale': colorscale,
            'colorbar': {'thickness': 20, 'title': label},
        })

    return Figure(
        data=[Scatter3d(
            x=pcd[:, 0],
            y=pcd[:, 1],
            z=pcd[:, 2],
            mode='markers',
            marker=marker,
            hovertext=hovertext,
        )],
    )
def draw_pcd_from_file(
        filename: str | os.PathLike,
        render: bool = True,
        **kwargs
        ) -> Figure | None:
    r"""
    Visualize point cloud from a file.

    Parameters
    ----------
    filename : str or path-like
        Absolute or relative path to a ``.npy`` or structure file. The
        ``.npy`` extension is matched case-insensitively; any other file is
        read as a structure file.
    render : bool, default=True
        If :obj:`True`, render the point cloud with
        :data:`plotly.io.renderers.default` and return :obj:`None`.
        Otherwise, return the figure object without rendering it.
    **kwargs
        Valid keyword arguments for :func:`draw_pcd`.

    Returns
    -------
    plotly.graph_objects.Figure or None
    """
    # Accept str, bytes and os.PathLike (e.g. pathlib.Path) inputs alike.
    filename = os.fsdecode(filename)

    if filename.lower().endswith('.npy'):
        pcd = np.load(filename)
    else:
        _, pcd = pcd_from_file(filename)

    fig = draw_pcd(pcd, **kwargs)

    if render:
        return fig.show()

    return fig