Source code for prolint2.metrics.base

r""":mod:`prolint2.metrics.base`
==========================================================
:Authors: Daniel P. Ramirez & Besian I. Sejdiu
:Year: 2022
:Copyright: MIT License
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable, List, Optional, Type, Union

from scipy.optimize import curve_fit

from prolint2.metrics.formatters import DefaultOutputFormat, OutputFormat


# Type alias used to annotate the ``registry`` argument of ``BaseMetric._register``.
# NOTE(review): "registries.MetricRegistry" is a string forward reference; ``registries``
# is not imported in this module's import block, so the alias is presumably only
# informational for type checkers — confirm.
MetricRegistry = Type["registries.MetricRegistry"]


class BaseMetric(ABC):
    """Base class for all metrics classes that act on single frame contact Iterables.

    Concrete metrics implement :meth:`compute_metric` and set ``name``, which is
    the key :meth:`_register` uses when adding the class to a registry.
    """

    name: str = None

    def __init__(self):
        # The base class carries no per-instance state.
        pass

    @abstractmethod
    def compute_metric(self, contact_array):
        """Compute this metric's value for a single contact array (subclass hook)."""

    @classmethod
    def _register(cls, registry: MetricRegistry):
        """Add this class to ``registry`` under the key ``cls.name``."""
        registry.register(cls.name, cls)
class Metric(ABC):
    """Base class for metric calculation.

    Applies one or more metric objects (anything exposing
    ``compute_metric(contact_array)``) to every residue/lipid contact array and
    records the results in an output-format object.
    """

    def __init__(
        self,
        contacts,
        metrics,
        output_format: Optional[OutputFormat] = None,
        lipid_type=None,
        clear=True,
    ):
        """Prepare the contact input, metric list and output storage.

        Parameters
        ----------
        contacts
            Object whose ``contacts`` attribute maps ``residue_id`` to a mapping
            of ``lipid_name -> contact_array``. A copy sorted by residue id is kept.
        metrics
            A list of metric objects, or a single metric object (any non-list value
            is wrapped in a one-element list).
        output_format : OutputFormat, optional
            Storage for the computed values. When omitted, a new
            ``DefaultOutputFormat`` is created for this instance.
        lipid_type : str, optional
            If given, only contact arrays for this lipid name are processed.
        clear : bool, default True
            Call ``output_format.clear()`` before it is used.
        """
        # Build the default per call: a default instance in the signature would be
        # created once at import time and shared by every Metric, so one instance's
        # ``clear()`` (or accumulation with clear=False) would leak into the others.
        if output_format is None:
            output_format = DefaultOutputFormat()

        self.contact_input = dict(sorted(contacts.contacts.items()))
        self.metrics = metrics if isinstance(metrics, list) else [metrics]
        if clear:
            output_format.clear()
        self.output_format = output_format
        self.lipid_type = lipid_type

    def compute(self, dt=1, totaltime=1):
        """Compute the metric for the given contacts.

        Every metric is evaluated on each (residue, lipid) contact array and the
        value, multiplied by ``dt / totaltime``, is passed to
        ``output_format.store_result`` under the metric's class name. Empty
        (falsy) contact arrays store ``0`` for every metric without calling it.

        Parameters
        ----------
        dt : float, default 1
            Numerator of the scaling factor applied to each computed value.
        totaltime : float, default 1
            Denominator of the scaling factor; must be non-zero.

        Returns
        -------
        Whatever ``self.output_format.get_result()`` returns.
        """
        multiplier = dt / totaltime
        # Names are looked up once instead of once per residue/lipid pair.
        named_metrics = [(metric, metric.__class__.__name__) for metric in self.metrics]

        for residue_id, lipid_dict in self.contact_input.items():
            for lipid_name, contact_array in lipid_dict.items():
                if self.lipid_type is not None and self.lipid_type != lipid_name:
                    continue
                has_contacts = bool(contact_array)
                for metric, metric_name in named_metrics:
                    if has_contacts:
                        value = metric.compute_metric(contact_array) * multiplier
                    else:
                        value = 0
                    self.output_format.store_result(
                        residue_id, lipid_name, metric_name, value
                    )
        return self.output_format.get_result()
class BaseContactStore:
    """Base class for storing contact.

    Keeps the residue ids/names of ``ts.database`` and a nested
    ``defaultdict`` (``self._contacts``) that is presumably filled by a
    subclass's :meth:`run` — confirm against the concrete subclasses.
    """

    def __init__(self, ts, contact_frames, norm_factor: float = 1.0):
        self.norm_factor = float(norm_factor)
        self.contact_frames = contact_frames
        residues = ts.database.residues
        self._resids = residues.resids
        self._resnames = residues.resnames
        self._database_unique_resnames = ts.database.unique_resnames
        # Two-level defaultdict whose leaves are plain dicts.
        self._contacts = defaultdict(lambda: defaultdict(dict))

    def run(self, lipid_resnames: Union[str, List] = None):
        """Run the contact calculation for the given lipid resnames. If no resnames are given, all resnames are used."""
        raise NotImplementedError("Subclasses should implement this method.")

    def compute(self, metric: str, target_lipid_name=None):
        """Compute a pre-defined metric for all lipids or a specific lipid.

        Parameters
        ----------
        metric : str
            The metric to compute. Must be one of 'max', 'sum', 'mean'.
        target_lipid_name : str, optional
            The name of the lipid to compute the metric for. If None, the metric
            will be computed for all lipids.

        Returns
        -------
        Dict[str, Dict[str, Dict[int, float]]]
            A dictionary of computed metrics for all lipids.

        Raises
        ------
        ValueError
            If ``metric`` is not one of the supported names.

        Examples
        --------
        >>> cd = AproxContacts(...)
        >>> cd.run()
        >>> cd.compute('max')
        >>> cd.compute('sum', 'DOPC')
        >>> cd.compute('median')  # raises ValueError. Use `apply_function` instead.
        """
        supported = ("max", "sum", "mean")
        if metric not in supported:
            raise ValueError(
                "Invalid metric specified. Use 'max', 'sum', 'mean'. For more complex metrics, use `apply_function`."
            )
        return self.compute_metric(metric, target_lipid_name)

    def compute_metric(self, metric: str, target_lipid_name=None):
        """Compute the given metric for the given lipid name."""
        raise NotImplementedError("Subclasses should implement this method.")

    def apply_function(self, func: Callable, target_lipid_name=None):
        """Apply the given function to the contacts for the given lipid name."""
        raise NotImplementedError("Subclasses should implement this method.")

    def pooled_results(self):
        """Get the computed contacts all pooled together."""
        raise NotImplementedError("Subclasses should implement this method.")

    # NOTE(review): ``__init__`` always sets ``_contacts`` to an (empty) defaultdict,
    # so the ``is None`` guards below only fire if something resets it to None.
    # Confirm whether an "empty / not yet run" check was intended.

    @property
    def results(self):
        """Get the computed contacts per lipid id."""
        if self._contacts is not None:
            return self._contacts
        raise ValueError("No contacts have been computed yet. Call run() first.")

    @property
    def contacts(self):
        """Get the computed contacts all pooled together."""
        if self._contacts is not None:
            return self.pooled_results()
        raise ValueError("No contacts have been computed yet. Call run() first.")
class FittingFunctionMeta(type):
    """Metaclass for fitting functions.

    The first class created with this metaclass gets an empty ``registry`` dict.
    Every later class that inherits that attribute is recorded in it under its
    ``name`` attribute (later classes with the same name overwrite earlier ones).
    """

    def __init__(cls, name, bases, dct):
        if hasattr(cls, "registry"):
            # Registry inherited from the root class: record this subclass.
            cls.registry[cls.name] = cls
        else:
            # No registry yet, so this class is the root and owns a fresh one.
            cls.registry = {}
        super().__init__(name, bases, dct)
class FittingFunction(metaclass=FittingFunctionMeta):
    """Base class for fitting functions.

    Subclasses set ``name`` (their key in the metaclass registry), may override
    the default initial guess ``p0`` and evaluation budget ``maxfev``, and
    implement :meth:`compute` and :meth:`get_koff`.
    """

    name = None
    p0 = [1, 1, 1, 1]
    maxfev = 1000000

    def compute(self, x, *params):
        """Model function evaluated by ``curve_fit`` (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_koff(self, popt):
        """Derive koff from fitted parameters (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def fit(self, x_data, y_data, **kwargs):
        """Fit :meth:`compute` to the data with ``scipy.optimize.curve_fit``.

        ``p0`` and ``maxfev`` fall back to the class defaults unless supplied in
        ``kwargs``; all other keyword arguments are forwarded unchanged.
        Returns only the optimal parameters (the covariance is discarded).
        """
        kwargs.setdefault("p0", self.p0)
        kwargs.setdefault("maxfev", self.maxfev)
        best_params, _covariance = curve_fit(self.compute, x_data, y_data, **kwargs)
        return best_params