Source code for prolint.plotting.kinetics

"""Kinetics plotters for binding dynamics visualization.

This module provides survival curves and residence time distributions
for contact kinetics analysis results.
"""

from typing import Optional, Tuple
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes

from prolint.analysis.base import AnalysisResult
from prolint.plotting.base import BasePlotter, PlottingRegistry
from prolint.plotting.theme import COLORS, apply_prolint_style, get_unit_label


class SurvivalCurvePlotter(BasePlotter):
    """Plotter for survival curves with exponential fits.

    Visualizes contact survival probability over lag time with optional
    mono- and bi-exponential model fits.

    See Also
    --------
    KineticsAnalysis : Generates survival curve data
    ResidenceDistributionPlotter : Residence time histograms
    """

    name = "survival_curve"
    required_analysis = "kinetics"
    description = "Survival curve with mono/bi-exponential fits"

    @classmethod
    def validate_result(cls, result: AnalysisResult) -> None:
        """Validate that result contains required survival curve data.

        Parameters
        ----------
        result : AnalysisResult
            Result whose ``data`` mapping is inspected.

        Raises
        ------
        ValueError
            If ``result.data`` has no ``"survival_curve"`` entry, or if that
            entry lacks ``"lag_times"`` or ``"survival_probability"``.
        """
        if "survival_curve" not in result.data:
            raise ValueError(
                f"AnalysisResult missing 'survival_curve' key for {cls.name}. "
                f"Expected result from '{cls.required_analysis}' analysis."
            )

        survival = result.data["survival_curve"]
        required_keys = ["lag_times", "survival_probability"]
        missing = [k for k in required_keys if k not in survival]
        if missing:
            raise ValueError(
                f"survival_curve missing required keys: {missing}. "
                f"Expected result from '{cls.required_analysis}' analysis."
            )

    @staticmethod
    def _label_axes(ax: Axes, xlabel: str, ylabel: str, title: str) -> None:
        """Apply the axis labels and title with the shared font settings."""
        ax.set_xlabel(xlabel, fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12, fontweight="semibold")

    @staticmethod
    def _plot_fit(
        ax: Axes,
        x_values: np.ndarray,
        fitted_curve,
        label: str,
        color,
        is_selected: bool,
    ) -> None:
        """Draw one fitted curve; the selected model is solid and heavier."""
        ax.plot(
            x_values,
            fitted_curve,
            color=color,
            linewidth=2.5 if is_selected else 1.5,
            alpha=1.0 if is_selected else 0.5,
            linestyle="-" if is_selected else "--",
            label=label,
        )

    @classmethod
    def plot(
        cls,
        result: AnalysisResult,
        xlabel: str = "Lag Time (frames)",
        ylabel: str = "Survival Probability",
        title: str = "Survival Curve",
        figsize: Tuple[float, float] = (8, 5),
        ax: Optional[Axes] = None,
        show_legend: bool = True,
        time_units: Optional[str] = None,
        dt: float = 1.0,
    ) -> Tuple[Figure, Axes]:
        """Create survival curve plot with exponential fits.

        Parameters
        ----------
        result : AnalysisResult
            Result from kinetics analysis.
        xlabel : str, default="Lag Time (frames)"
            X-axis label. Replaced by a generated "Lag Time (<unit>)" label
            when ``time_units`` is given.
        ylabel : str, default="Survival Probability"
            Y-axis label.
        title : str, default="Survival Curve"
            Axes title.
        figsize : tuple of float, default=(8, 5)
            Size of the figure created when ``ax`` is not supplied.
        ax : Axes, optional
            Existing axes to plot on.
        show_legend : bool, default=True
            Whether to show fit parameters in legend.
        time_units : str, optional
            Time unit for x-axis (e.g., "ns", "us").
        dt : float, default=1.0
            Time step multiplier when using time_units.

        Returns
        -------
        tuple of (Figure, Axes)
            Matplotlib figure and axes objects.

        Raises
        ------
        ValueError
            If ``result`` fails :meth:`validate_result`.

        Warns
        -----
        UserWarning
            If the number of events is below the minimum needed for a fit;
            a placeholder message is drawn instead of the curve.
        """
        cls.validate_result(result)
        apply_prolint_style()

        # Extract data from result
        survival = result.data["survival_curve"]
        kinetics = result.data.get("kinetics", {})
        n_events = kinetics.get("n_events", 0)
        min_events_mono = survival.get("min_events_mono", 5)

        # Check for insufficient data (same limit as frontend)
        if n_events < min_events_mono:
            warnings.warn(
                f"Insufficient data for survival curve plotting ({n_events} events, "
                f"need >= {min_events_mono}). Skipping plot.",
                UserWarning,
            )
            if ax is None:
                fig, ax = plt.subplots(figsize=figsize)
            else:
                fig = ax.figure
            ax.text(
                0.5,
                0.5,
                f"Insufficient data\n({n_events} events, need >= {min_events_mono})",
                ha="center",
                va="center",
                transform=ax.transAxes,
                fontsize=11,
                color=COLORS["text"]["secondary"],
            )
            cls._label_axes(ax, xlabel, ylabel, title)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            return fig, ax

        lag_times = survival["lag_times"]
        survival_probability = survival["survival_probability"]
        mono_fit = survival.get("mono_fit")
        bi_fit = survival.get("bi_fit")
        selected_model = survival.get("selected_model")

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.figure

        # Convert to time if specified
        if time_units:
            x_values = np.array(lag_times) * dt
            xlabel = f"Lag Time ({get_unit_label(time_units)})"
        else:
            x_values = np.array(lag_times)

        # Plot raw survival curve data points
        ax.scatter(
            x_values,
            survival_probability,
            color=COLORS["neutral"]["500"],
            s=30,
            alpha=0.6,
            label="Data",
            zorder=5,
        )

        # NOTE(review): the rate constants and half-life shown in the legend
        # are printed exactly as stored in the fit dicts; they are not scaled
        # by ``dt`` even when the x-axis is. Confirm the intended units.

        # Plot monoexponential fit if available
        if mono_fit is not None and "fitted_curve" in mono_fit:
            k_off = mono_fit.get("k_off", 0)
            r2 = mono_fit.get("r_squared", 0)
            half_life = mono_fit.get("half_life")

            label = f"Mono: k={k_off:.3f}, R²={r2:.3f}"
            # Truthiness check: a missing (None) or zero half-life is omitted.
            if half_life:
                label += f", t½={half_life:.1f}"

            cls._plot_fit(
                ax,
                x_values,
                mono_fit["fitted_curve"],
                label,
                COLORS["data"]["query"],
                selected_model == "monoexponential",
            )

        # Plot biexponential fit if available
        if bi_fit is not None and "fitted_curve" in bi_fit:
            k_fast = bi_fit.get("k_fast", 0)
            k_slow = bi_fit.get("k_slow", 0)
            r2 = bi_fit.get("r_squared", 0)

            label = f"Bi: k_fast={k_fast:.3f}, k_slow={k_slow:.3f}, R²={r2:.3f}"

            cls._plot_fit(
                ax,
                x_values,
                bi_fit["fitted_curve"],
                label,
                COLORS["data"]["database"],
                selected_model == "biexponential",
            )

        cls._label_axes(ax, xlabel, ylabel, title)
        ax.set_ylim(0, 1.05)
        # Only pin the x-range when there are at least two lag times: an empty
        # array cannot be indexed, and a single point would give identical
        # lower and upper limits.
        if x_values.size > 1:
            ax.set_xlim(x_values[0], x_values[-1])
        ax.grid(True, alpha=0.3, linestyle="--")

        if show_legend:
            ax.legend(loc="upper right", fontsize=9, framealpha=0.9)

        # Lay out the figure that owns ``ax``. ``plt.tight_layout()`` would act
        # on pyplot's current figure, which need not be the same one when the
        # caller supplied ``ax``.
        fig.tight_layout()
        return fig, ax
class ResidenceDistributionPlotter(BasePlotter):
    """Plotter for residence time distributions.

    Visualizes the distribution of contact durations as a histogram.

    See Also
    --------
    KineticsAnalysis : Generates residence distribution data
    SurvivalCurvePlotter : Survival curves
    """

    name = "residence_distribution"
    required_analysis = "kinetics"
    description = "Histogram of residence time durations"

    @classmethod
    def validate_result(cls, result: AnalysisResult) -> None:
        """Validate that result contains required residence distribution data.

        Parameters
        ----------
        result : AnalysisResult
            Result whose ``data`` mapping is inspected.

        Raises
        ------
        ValueError
            If ``result.data`` has no ``"residence_distribution"`` entry, or
            if that entry lacks ``"bins"`` or ``"counts"``.
        """
        if "residence_distribution" not in result.data:
            raise ValueError(
                f"AnalysisResult missing 'residence_distribution' key for {cls.name}. "
                f"Expected result from '{cls.required_analysis}' analysis."
            )

        residence = result.data["residence_distribution"]
        required_keys = ["bins", "counts"]
        missing = [k for k in required_keys if k not in residence]
        if missing:
            raise ValueError(
                f"residence_distribution missing required keys: {missing}. "
                f"Expected result from '{cls.required_analysis}' analysis."
            )

    @classmethod
    def plot(
        cls,
        result: AnalysisResult,
        xlabel: str = "Residence Time (frames)",
        ylabel: str = "Count",
        title: str = "Residence Time Distribution",
        figsize: Tuple[float, float] = (8, 4),
        ax: Optional[Axes] = None,
        time_units: Optional[str] = None,
        dt: float = 1.0,
        log_scale: bool = False,
    ) -> Tuple[Figure, Axes]:
        """Create residence time histogram.

        Parameters
        ----------
        result : AnalysisResult
            Result from kinetics analysis.
        xlabel : str, default="Residence Time (frames)"
            X-axis label. Replaced by a generated "Residence Time (<unit>)"
            label when ``time_units`` is given.
        ylabel : str, default="Count"
            Y-axis label.
        title : str, default="Residence Time Distribution"
            Axes title.
        figsize : tuple of float, default=(8, 4)
            Size of the figure created when ``ax`` is not supplied.
        ax : Axes, optional
            Existing axes to plot on.
        time_units : str, optional
            Time unit for x-axis. When given, the bins and the mean shown in
            the stats annotation are both multiplied by ``dt``.
        dt : float, default=1.0
            Time step multiplier.
        log_scale : bool, default=False
            Whether to use log scale for y-axis. Ignored when there are no
            positive counts.

        Returns
        -------
        tuple of (Figure, Axes)
            Matplotlib figure and axes objects.

        Raises
        ------
        ValueError
            If ``result`` fails :meth:`validate_result`.
        """
        cls.validate_result(result)
        apply_prolint_style()

        # Extract data from result
        residence = result.data["residence_distribution"]
        bins = residence["bins"]
        counts = residence["counts"]
        kinetics = result.data.get("kinetics", {})
        mean_residence = kinetics.get("mean_residence_time", 0)
        n_events = kinetics.get("n_events", sum(counts))

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.figure

        # Convert bins to time if specified. The same unit and scale are used
        # for the stats annotation so the axis and the annotation agree.
        if time_units:
            unit_label = get_unit_label(time_units)
            x_values = np.array(bins) * dt
            xlabel = f"Residence Time ({unit_label})"
            mean_display = mean_residence * dt
        else:
            unit_label = "frames"
            x_values = np.array(bins)
            mean_display = mean_residence

        # Calculate bar width from the spacing of the first two bins
        if len(x_values) > 1:
            width = (x_values[1] - x_values[0]) * 0.8
        else:
            width = 0.8

        ax.bar(
            x_values,
            counts,
            width=width,
            color=COLORS["data"]["query"],
            edgecolor=COLORS["data"]["query_dark"],
            alpha=0.8,
        )

        ax.set_xlabel(xlabel, fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12, fontweight="semibold")

        # Check for emptiness before max(): max() raises ValueError on an
        # empty sequence.
        if log_scale and len(counts) > 0 and max(counts) > 0:
            ax.set_yscale("log")

        ax.set_xlim(0, x_values[-1] + width if len(x_values) > 0 else 10)
        ax.grid(True, alpha=0.3, linestyle="--", axis="y")

        # Add stats using pre-computed mean from kinetics
        stats_text = f"N={n_events}, Mean={mean_display:.1f} {unit_label}"
        ax.text(
            0.98,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=9,
            ha="right",
            va="top",
            color=COLORS["text"]["secondary"],
        )

        # Lay out the figure that owns ``ax``. ``plt.tight_layout()`` would act
        # on pyplot's current figure, which need not be the same one when the
        # caller supplied ``ax``.
        fig.tight_layout()
        return fig, ax
# Register plotters with PlottingRegistry, keyed by each class's ``name``
# attribute ("survival_curve" and "residence_distribution").
PlottingRegistry.register(SurvivalCurvePlotter.name, SurvivalCurvePlotter)
PlottingRegistry.register(
    ResidenceDistributionPlotter.name, ResidenceDistributionPlotter
)