Source code for prolint.plotting.density

"""Density map plotters for spatial distribution visualization.

This module provides 2D density and radial density visualizations
for database molecule distributions around query atoms.
"""

from typing import Optional, Tuple, List
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm, Normalize

from prolint.analysis.base import AnalysisResult
from prolint.plotting.base import BasePlotter, PlottingRegistry
from prolint.plotting.theme import COLORS, get_colormap, apply_prolint_style


[docs] class DensityMapPlotter(BasePlotter): """Plotter for 2D spatial density maps. Visualizes the spatial distribution of database molecules around the query center of mass. See Also -------- DensityMapAnalysis : Generates density map data RadialDensityPlotter : Radial density profiles """
[docs] name = "density_map"
[docs] required_analysis = "density_map"
[docs] description = "2D spatial density of database molecules around query"
@classmethod
[docs] def validate_result(cls, result: AnalysisResult) -> None: """Validate that result contains required density map data.""" required_keys = ["density", "x_edges", "y_edges"] missing = [k for k in required_keys if k not in result.data] if missing: raise ValueError( f"AnalysisResult missing required keys for {cls.name}: {missing}. " f"Expected result from '{cls.required_analysis}' analysis." )
@classmethod
[docs] def plot( cls, result: AnalysisResult, colorscheme: str = "viridis", xlabel: str = "X (Å)", ylabel: str = "Y (Å)", title: str = "Density Map", figsize: Tuple[float, float] = (8, 8), ax: Optional[Axes] = None, show_colorbar: bool = True, cbar_label: str = "Density", log_scale: bool = False, aspect: str = "equal", vmin: Optional[float] = None, vmax: Optional[float] = None, show_query_contours: bool = True, highlight_query_residues: Optional[List[int]] = None, highlight_database_ids: Optional[List[int]] = None, universe=None, frame: int = 0, ) -> Tuple[Figure, Axes]: """Create 2D density map visualization. Parameters ---------- result : AnalysisResult Result from density_map analysis. colorscheme : str, default="viridis" Color scale name. log_scale : bool, default=False Whether to use logarithmic color scale. show_query_contours : bool, default=True Whether to overlay query density as contours. highlight_query_residues : list of int, optional Query residues to highlight on the map. highlight_database_ids : list of int, optional Database molecules to highlight. universe : Universe, optional Required for highlighting residues. ax : Axes, optional Existing axes to plot on. Returns ------- tuple of (Figure, Axes) Matplotlib figure and axes objects. """ cls.validate_result(result) apply_prolint_style() # Extract data from result density = np.asarray(result.data["density"]) x_edges = result.data["x_edges"] y_edges = result.data["y_edges"] if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.figure cmap = get_colormap(colorscheme) x_min, x_max = x_edges[0], x_edges[-1] y_min, y_max = y_edges[0], y_edges[-1] if vmin is None: vmin = np.nanmin(density[density > 0]) if log_scale else np.nanmin(density) if vmax is None: vmax = np.nanmax(density) if log_scale: density_plot = np.where(density > 0, density, vmin) norm = LogNorm(vmin=vmin, vmax=vmax) else: density_plot = density norm = Normalize(vmin=vmin, vmax=vmax) im = ax.imshow( density_plot.T, origin="lower", extent=[x_min, x_max, y_min, y_max], cmap=cmap, norm=norm, aspect=aspect, interpolation="bilinear", ) ax.set_xlabel(xlabel, fontsize=11) ax.set_ylabel(ylabel, fontsize=11) ax.set_title(title, fontsize=12, fontweight="semibold") if show_colorbar: cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) cbar.set_label(cbar_label, fontsize=10) if log_scale: cbar.ax.set_ylabel(f"{cbar_label} (log scale)", fontsize=10) # Overlay query density as contours using pre-computed centers if show_query_contours and "query_density" in result.data: query_density = np.asarray(result.data["query_density"]) x_centers = result.data["x_centers"] y_centers = result.data["y_centers"] X, Y = np.meshgrid(x_centers, y_centers) ax.contour( X, Y, query_density.T, levels=5, colors=COLORS["data"]["highlight"], linewidths=1.5, alpha=0.8, ) ax.annotate( "Query (contours)", xy=(0.02, 0.98), xycoords="axes fraction", fontsize=9, color=COLORS["data"]["highlight"], va="top", ) # Highlight specific residues if requested if universe is not None and ( highlight_query_residues or highlight_database_ids ): universe.trajectory[frame] query_com = universe.query.atoms.center_of_mass()[:2] legend_handles = [] if highlight_query_residues: for resid in highlight_query_residues: residue_atoms = universe.query.select_atoms(f"resid {resid}") if len(residue_atoms) > 0: pos = residue_atoms.center_of_mass()[:2] - query_com scatter = ax.scatter( pos[0], pos[1], s=150, c=COLORS["data"]["query"], edgecolors="white", linewidths=2, marker="o", zorder=10, ) ax.annotate( str(resid), (pos[0], pos[1]), xytext=(5, 5), textcoords="offset points", fontsize=9, fontweight="bold", color=COLORS["data"]["query"], ) if highlight_query_residues: legend_handles.append( plt.Line2D( [0], [0], marker="o", color="w", markerfacecolor=COLORS["data"]["query"], markersize=10, label="Query residues", ) ) if highlight_database_ids: for resid in highlight_database_ids: residue_atoms = universe.database.select_atoms(f"resid {resid}") if len(residue_atoms) > 0: pos = residue_atoms.center_of_mass()[:2] - query_com scatter = ax.scatter( pos[0], pos[1], s=150, c=COLORS["data"]["database"], edgecolors="white", linewidths=2, marker="s", zorder=10, ) ax.annotate( str(resid), (pos[0], pos[1]), xytext=(5, 5), textcoords="offset points", fontsize=9, fontweight="bold", color=COLORS["data"]["database"], ) if highlight_database_ids: legend_handles.append( plt.Line2D( [0], [0], marker="s", color="w", markerfacecolor=COLORS["data"]["database"], markersize=10, label="Database residues", ) ) if legend_handles: ax.legend(handles=legend_handles, loc="lower right", fontsize=9) plt.tight_layout() return fig, ax
[docs] class RadialDensityPlotter(BasePlotter): """Plotter for radial density profiles. Visualizes radially-averaged density as a function of distance from the query center. See Also -------- RadialDensityAnalysis : Generates radial density data DensityMapPlotter : 2D density maps """
[docs] name = "radial_density"
[docs] required_analysis = "radial_density"
[docs] description = "Radial density profile from 2D density map"
@classmethod
[docs] def validate_result(cls, result: AnalysisResult) -> None: """Validate that result contains required radial density data.""" required_keys = ["r_centers", "radial_density"] missing = [k for k in required_keys if k not in result.data] if missing: raise ValueError( f"AnalysisResult missing required keys for {cls.name}: {missing}. " f"Expected result from '{cls.required_analysis}' analysis." )
@classmethod
[docs] def plot( cls, result: AnalysisResult, xlabel: str = "Distance from Center (Å)", ylabel: str = "Radial Density", title: str = "Radial Density Profile", figsize: Tuple[float, float] = (8, 4), ax: Optional[Axes] = None, ) -> Tuple[Figure, Axes]: """Create radial density profile plot. Parameters ---------- result : AnalysisResult Result from radial_density analysis. ax : Axes, optional Existing axes to plot on. Returns ------- tuple of (Figure, Axes) Matplotlib figure and axes objects. """ cls.validate_result(result) apply_prolint_style() # Extract data from result r_centers = result.data["r_centers"] radial_density = result.data["radial_density"] r_max = result.data.get("r_max", max(r_centers) if r_centers else 1.0) if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.figure ax.plot(r_centers, radial_density, color=COLORS["data"]["query"], linewidth=2) ax.fill_between( r_centers, radial_density, alpha=0.3, color=COLORS["data"]["query"] ) ax.set_xlabel(xlabel, fontsize=11) ax.set_ylabel(ylabel, fontsize=11) ax.set_title(title, fontsize=12, fontweight="semibold") ax.set_xlim(0, r_max) ax.set_ylim(0, None) ax.grid(True, alpha=0.3, linestyle="--") plt.tight_layout() return fig, ax
# Register plotters PlottingRegistry.register("density_map", DensityMapPlotter) PlottingRegistry.register("radial_density", RadialDensityPlotter)