Source code for prolint2.metrics.exact_contacts

r""":mod:`prolint2.metrics.exact_contacts`
==========================================================
:Authors: Daniel P. Ramirez & Besian I. Sejdiu
:Year: 2022
:Copyright: MIT License
"""

from typing import List, Dict, Callable, Union

from collections import defaultdict

import numpy as np

from prolint2.metrics.base import BaseContactStore
from prolint2.metrics.utils import (
    fast_filter_resids_by_resname,
    fast_contiguous_segment_lengths,
)


[docs]class ExactContacts(BaseContactStore): """Compute the duration of lipid contacts. This class is used to compute the duration of lipid contacts."""
[docs] def run(self, lipid_resnames: Union[str, List] = None) -> Dict[str, np.ndarray]: """Compute the duration of lipid contacts for all lipid types. Parameters ---------- lipid_resnames : str, optional A list of lipid residue names to compute durations for. If None, durations will be computed for all lipid types. Returns ------- Dict[str, np.ndarray] A dictionary of lipid contact durations for all lipid types. The output is stored in the `self._contacts` attribute. """ if lipid_resnames is None: lipid_resnames = self._database_unique_resnames elif isinstance(lipid_resnames, str): lipid_resnames = [lipid_resnames] for residue, contact_frame in self.contact_frames.items(): for lipid_resname in lipid_resnames: result = self.compute_lipid_durations(contact_frame, lipid_resname) if len(result) > 0: self._contacts[residue][lipid_resname] = result
[docs] def pooled_results(self, target_lipid_name=None): """Pool results for all lipids. Parameters ---------- target_lipid_name : str, optional The name of the lipid to compute pooled results for. If None, pooled results will be computed for all lipids. Returns ------- Dict[str, Dict[str, List[float]]] A dictionary of pooled results for all lipids. """ pooled_results = defaultdict(lambda: defaultdict(list)) for residue, lipid_data in self._contacts.items(): for lipid_name, lipid_contacts in lipid_data.items(): if target_lipid_name is None or lipid_name == target_lipid_name: pooled_contact_array = [] for lipid_id_contacts in lipid_contacts.values(): pooled_contact_array.extend(lipid_id_contacts) pooled_results[residue][lipid_name].extend(pooled_contact_array) return pooled_results
[docs] def compute_metric(self, metric: str, target_lipid_name=None): """Compute a pre-defined metric for all lipids or a specific lipid. Parameters ---------- metric : str The metric to compute. Must be one of 'max', 'sum', 'mean'. target_lipid_name : str, optional The name of the lipid to compute the metric for. If None, the metric will be computed for all lipids. Returns ------- Dict[str, Dict[str, Dict[int, float]]] A dictionary of computed metrics for all lipids. """ computed_results = defaultdict(lambda: defaultdict(dict)) for residue, lipid_data in self._contacts.items(): # computed_results[residue] = {} for lipid_name, lipid_contacts in lipid_data.items(): if target_lipid_name is None or lipid_name == target_lipid_name: computed_contacts_per_id = { lipid_id: getattr(np, metric)(contact_array) for lipid_id, contact_array in lipid_contacts.items() } computed_results[residue][lipid_name] = computed_contacts_per_id return computed_results
[docs] def apply_function(self, func: Callable, target_lipid_name=None): """Apply a function to all lipids or a specific lipid. Parameters ---------- func : Callable The function to apply to the lipid contact durations. target_lipid_name : str, optional The name of the lipid to apply the function to. If None, the function will be applied to all lipids. Returns ------- Dict[str, Dict[str, Dict[int, float]]] A dictionary of computed metrics for all lipids. Example ------- >>> cd = ExactContacts(...) >>> cd.run() >>> cd.apply_function(np.mean) >>> cd.apply_function(np.max, target_lipid_name='DOPC') >>> cd.apply_function(lambda x: np.mean(x) / np.max(x), target_lipid_name='DOPC') """ computed_results = {} for residue, lipid_data in self._contacts.items(): computed_results[residue] = {} for lipid_name, lipid_contacts in lipid_data.items(): if target_lipid_name is None or lipid_name == target_lipid_name: computed_contacts_per_id = { lipid_id: func(contact_array) for lipid_id, contact_array in lipid_contacts.items() } computed_results[residue][lipid_name] = computed_contacts_per_id return computed_results
[docs] def compute_lipid_durations( self, contact_frame: Dict[int, List[int]], lipid_resname: str ) -> np.ndarray: """Compute the duration of lipid contacts. Parameters ---------- contact_frame : Dict[int, List[int]] A dictionary of contact frames. lipid_resname : str The residue name of the lipid to compute durations for. Returns ------- np.ndarray An array of lipid contact durations. """ ids_to_filter = np.array(list(contact_frame.keys())) lipid_ids = fast_filter_resids_by_resname( self._resids, self._resnames, ids_to_filter, lipid_resname ) durations = {} for k, arr in contact_frame.items(): if k in lipid_ids: durations[k] = fast_contiguous_segment_lengths(arr, self.norm_factor) return durations