Source code for prolint2.core.contact_provider
r""":mod:`prolint2.core.contact_provider`
==========================================================
:Authors: Daniel P. Ramirez & Besian I. Sejdiu
:Year: 2022
:Copyright: MIT License
"""
from collections import defaultdict
from typing import Callable, Literal
import numpy as np
import pandas as pd
from prolint2.computers.contacts import ContactComputerBase, SerialContacts
from prolint2.core.typing import (
NestedFloatDict,
NestedIterFloatDict,
NestedIterIntDict,
LipidId,
)
from prolint2.metrics.base import BaseContactStore
from prolint2.metrics.exact_contacts import ExactContacts
from prolint2.metrics.aprox_contacts import AproxContacts
from prolint2.config.units import DEFAULT_SIM_PARAMS
class ComputedContacts:
    """A class to compute contacts between residues and lipids.

    Wraps a contact strategy instance (which holds the stored contacts) together with the
    :class:`ContactsProvider` that produced it. Exposes metrics, lookups and set-like
    operations on the stored contacts: ``+`` is :meth:`intersection` and ``-`` is
    :meth:`difference`.

    Parameters
    ----------
    contact_strategy_instance : BaseContactStore
        An instance of a contact strategy class.
    provider : ContactsProvider
        The contact provider that will be used to compute contacts.
    """

    def __init__(
        self, contact_strategy_instance: BaseContactStore, provider: "ContactsProvider"
    ):
        self._contact_strategy = contact_strategy_instance
        self.provider = provider

    def compute_metric(self, metric: str, target_lipid_name=None) -> NestedFloatDict:
        """Compute a pre-defined metric for all lipids or a specific lipid.

        Delegates to the contact strategy's ``compute`` method.

        Parameters
        ----------
        metric : str
            The metric to compute. Must be one of 'max', 'sum', 'mean'.
        target_lipid_name : str, optional
            The name of the lipid to compute the metric for. If None, the metric will be
            computed for all lipids.

        Returns
        -------
        Dict[str, Dict[str, Dict[int, float]]]
            A dictionary of computed metrics for all lipids.

        Examples
        --------
        >>> c.compute_metric('max')
        >>> c.compute_metric('sum', 'DOPC')
        >>> c.compute_metric('median')  # raises ValueError. Use `apply_function` instead.
        """
        return self._contact_strategy.compute(
            metric, target_lipid_name=target_lipid_name
        )

    def apply_function(self, func: Callable, target_lipid_name=None) -> NestedFloatDict:
        """Apply ``func`` to the stored contacts, optionally for a single lipid name.

        Delegates to the contact strategy's ``apply_function``; how ``func`` is invoked
        (its expected arguments) is defined by that strategy.

        Parameters
        ----------
        func : Callable
            The function to apply.
        target_lipid_name : str, optional
            Restrict the computation to this lipid name. If None, all lipids are used.
        """
        return self._contact_strategy.apply_function(
            func, target_lipid_name=target_lipid_name
        )

    @property
    def contacts(self) -> NestedIterFloatDict:
        """The computed contacts, as stored by the contact strategy."""
        return self._contact_strategy.contacts

    @property
    def pooled_contacts(self) -> NestedIterFloatDict:
        """The contacts as pooled by the contact strategy (``pooled_results()``)."""
        return self._contact_strategy.pooled_results()

    @property
    def contact_frames(self) -> NestedIterIntDict:
        """Frame indices of every contact, keyed by residue id and then lipid id."""
        return self._contact_strategy.contact_frames

    def create_dataframe(self, n_frames: int) -> pd.DataFrame:
        """Create a pandas DataFrame from the computed contacts.

        Each row is one ``(ResidueID, LipidId)`` pair. Column ``i`` holds 1 if that pair is
        in contact in frame ``i`` and 0 otherwise.

        Parameters
        ----------
        n_frames : int
            The number of frames in the trajectory, i.e. the number of columns.

        Returns
        -------
        pd.DataFrame
            An int8 contact matrix, sorted by residue id and then lipid id.
        """
        keys = []
        contact_arrays = []
        for residue_id, lipid_frames in self.contact_frames.items():
            for lipid_id, frame_indices in lipid_frames.items():
                contact_array = np.zeros(n_frames, dtype=np.int8)
                contact_array[frame_indices] = 1
                keys.append((residue_id, lipid_id))
                contact_arrays.append(contact_array)
        df = pd.DataFrame(
            contact_arrays,
            index=pd.MultiIndex.from_tuples(keys, names=["ResidueID", "LipidId"]),
            # Explicit columns keep the frame axis at n_frames even with no contacts.
            columns=pd.RangeIndex(n_frames),
        )
        return df.sort_index(level=["ResidueID", "LipidId"], ascending=[True, True])

    def get_lipids_by_residue_id(self, residue_id: int) -> list:
        """Get all LipidIds (sorted) that interact with the given ResidueID.

        Returns an empty list for a residue without contacts. The lookup does not insert
        the residue into ``contact_frames``.
        """
        return sorted(self.contact_frames.get(residue_id, {}))

    def get_residues_by_lipid_id(self, lipid_id: int) -> list:
        """Get all ResidueIDs that interact with the given LipidId."""
        return [
            residue_id
            for residue_id, lipid_frames in self.contact_frames.items()
            if lipid_id in lipid_frames
        ]

    def get_contact_data(
        self, residue_id: int, lipid_id: int, output: str = "contacts"
    ) -> list:
        """Get the contact data for a given residue and lipid.

        Parameters
        ----------
        residue_id : int
            The residue id.
        lipid_id : int
            The lipid id.
        output : str, optional
            The output format. Must be one of 'contacts' or 'indices'.

        Returns
        -------
        list
            For 'indices', the frame indices of the contact. For 'contacts', a 0/1 list
            whose length is one more than the largest frame index of any lipid contacting
            ``residue_id``. A residue/lipid pair without contacts yields an empty list of
            indices.

        Raises
        ------
        ValueError
            If ``output`` is not 'contacts' or 'indices'.
        """
        if output not in ("contacts", "indices"):
            raise ValueError(
                f"Unknown output format: {output!r}. Must be one of 'contacts' or 'indices'."
            )
        # .get avoids inserting empty entries into a defaultdict-backed store, which
        # would later break the max() below and add empty rows to create_dataframe.
        residue_frames = self.contact_frames.get(residue_id, {})
        frame_indices = residue_frames.get(lipid_id, [])
        if output == "indices":
            return frame_indices
        n_frames = (
            max(
                (max(frames) for frames in residue_frames.values() if len(frames) > 0),
                default=-1,
            )
            + 1
        )
        contact_set = set(frame_indices)
        return [1 if i in contact_set else 0 for i in range(n_frames)]

    def _from_contact_frames(self, contact_frames) -> "ComputedContacts":
        """Build a new ComputedContacts from ``contact_frames`` with this strategy/provider."""
        contact_instances = type(self._contact_strategy)(
            self.provider.query.universe, contact_frames
        )
        # Carry over the provider's normalization so derived results stay comparable.
        contact_instances.norm_factor = self.provider.params.get("norm_factor", 1)
        contact_instances.run()
        return ComputedContacts(contact_instances, self.provider)

    def intersection(self, other: "ComputedContacts") -> "ComputedContacts":
        """Compute the intersection of two contact providers.

        The result contains the (residue, lipid) pairs present in both ``self`` and
        ``other``. The frame indices of each pair are taken from ``other``.

        ProLint contacts use a radial cutoff, so the pairs of the provider with the
        smaller cutoff are (typically) a subset of those of the provider with the larger
        cutoff. Pass the larger-cutoff provider as ``other``. The result then has the
        lipid ids of the smaller cutoff and the frame indices of the larger cutoff, and
        results can be chained.

        Parameters
        ----------
        other : ComputedContacts
            The other contact provider to compute the intersection with. Its frame
            indices are used in the result.

        Returns
        -------
        ComputedContacts
            A new object with the intersection of the contacts of both providers.

        Examples
        --------
        >>> ts = Universe('coordinates.gro', 'trajectory.xtc')
        >>> c1 = ts.compute_contacts(cutoff=7)
        >>> c2 = ts.compute_contacts(cutoff=8)
        >>> c3 = c1 + c2  # pairs in contact at both cutoffs, frames taken from c2
        """
        result_data = defaultdict(lambda: defaultdict(list))
        other_frames = other.contact_frames
        for residue_id, lipid_frames in self.contact_frames.items():
            # .get: do not insert missing residues into the other provider's store.
            other_lipids = other_frames.get(residue_id, {})
            for lipid_id in lipid_frames:
                if LipidId(lipid_id) in other_lipids:
                    result_data[residue_id][lipid_id] = other_lipids[lipid_id]
        return self._from_contact_frames(result_data)

    def difference(self, other: "ComputedContacts") -> "ComputedContacts":
        """Compute the difference of two contact providers.

        Given two contact providers (c1 and c2), the difference c2 - c1 is defined as the
        contacts of c2 whose (residue, lipid) pair is not present in c1. Frame indices are
        taken from ``self``.

        Parameters
        ----------
        other : ComputedContacts
            The other contact provider to compute the difference with.

        Returns
        -------
        ComputedContacts
            A new object with the contacts of ``self`` that are absent from ``other``.

        Examples
        --------
        >>> ts = Universe('coordinates.gro', 'trajectory.xtc')
        >>> c1 = ts.compute_contacts(cutoff=7)
        >>> c2 = ts.compute_contacts(cutoff=8)
        >>> c3 = c2 - c1  # pairs within 8 but never within 7
        >>> c4 = c1 - c2  # empty if the pairs of c1 are a subset of those of c2
        """
        result_data = defaultdict(lambda: defaultdict(list))
        other_frames = other.contact_frames
        for residue_id, lipid_frames in self.contact_frames.items():
            # .get: do not insert missing residues into the other provider's store.
            other_lipids = other_frames.get(residue_id, {})
            for lipid_id, frame_indices in lipid_frames.items():
                if LipidId(lipid_id) not in other_lipids:
                    result_data[residue_id][lipid_id] = frame_indices
        # Shared builder so the result gets the provider's norm_factor, like intersection.
        return self._from_contact_frames(result_data)

    def __add__(self, other: "ComputedContacts") -> "ComputedContacts":
        """``self + other`` is :meth:`intersection`."""
        return self.intersection(other)

    def __sub__(self, other: "ComputedContacts") -> "ComputedContacts":
        """``self - other`` is :meth:`difference`."""
        return self.difference(other)
class ContactsProvider:
    """Class that provides the contacts computation functionality.

    Pairs a ``query`` and a ``database`` with two strategies:

    - a compute strategy: the contact computer that finds the contacts;
    - a contact strategy (``"exact"`` or ``"aprox"``): the store class that counts and
      stores them.

    Parameters
    ----------
    query
        The query; its ``universe`` attribute is passed to the contact computer and store.
    database
        The database passed to the contact computer.
    params : dict, optional
        Simulation parameters. Only ``"norm_factor"`` is read here. If None, a copy of
        ``DEFAULT_SIM_PARAMS`` is used.
    compute_strategy : {"default"}
        Name of the contact computer used when :meth:`compute` is called without one.
    contact_strategy : {"exact", "aprox"}
        Name of the contact store class used to count and store contacts.
    """

    def __init__(
        self,
        query,
        database,
        params=None,
        compute_strategy: Literal["default"] = "default",
        contact_strategy: Literal["exact", "aprox"] = "exact",
    ):
        self.query = query
        self.database = database
        self._contact_computers = {"default": SerialContacts}
        self._contact_counter = {"exact": ExactContacts, "aprox": AproxContacts}
        self._compute_strategy = compute_strategy
        self._contact_strategy = self._contact_counter[contact_strategy]
        # Copy the defaults so that mutating self.params never alters the shared
        # module-level DEFAULT_SIM_PARAMS used by other providers.
        self.params = params if params is not None else dict(DEFAULT_SIM_PARAMS)

    def compute(
        self, strategy_or_computer=None, start=None, stop=None, step=1, **kwargs
    ):
        """Compute contacts between the query and the database.

        Parameters
        ----------
        strategy_or_computer : str or ContactComputerBase, optional
            The strategy to compute contacts: either the name of a registered computer or
            a ready-made ``ContactComputerBase`` instance. If None, the provider's default
            compute strategy is used.
        start, stop, step : optional
            Forwarded to the contact computer's ``run`` method.
        **kwargs
            Additional arguments passed to the contact computer's constructor. They are
            ignored when a ``ContactComputerBase`` instance is given.

        Returns
        -------
        ComputedContacts
            The computed contacts.

        Raises
        ------
        ValueError
            If ``strategy_or_computer`` is neither a computer instance nor a known
            strategy name.
        """
        if strategy_or_computer is None:
            strategy_or_computer = self._compute_strategy

        # Strategy to compute contacts (e.g. serial, parallel, etc.)
        if isinstance(strategy_or_computer, ContactComputerBase):
            contact_computer = strategy_or_computer
        else:
            contact_computer_class = self._contact_computers.get(
                strategy_or_computer, None
            )
            if contact_computer_class is None:
                strats = ", ".join(self._contact_computers.keys())
                raise ValueError(
                    f"Unknown strategy or computer: {strategy_or_computer}. Available strategies are: {strats}."
                )
            contact_computer = contact_computer_class(
                self.query.universe, self.query, self.database, **kwargs
            )
        contact_computer.run(verbose=True, start=start, stop=stop, step=step)

        # Strategy to count and store contacts (e.g. exact, aprox, etc.)
        contact_strategy_instance = self._contact_strategy(
            self.query.universe,
            contact_computer.contact_frames,
            self.params.get("norm_factor"),
        )
        contact_strategy_instance.run()
        return ComputedContacts(contact_strategy_instance, self)

    def load_from_file(self, file, **kwargs):
        """Load contacts from a file.

        The file is read as CSV. Its first two columns form a ``(residue id, lipid id)``
        index, and every remaining column is one frame. The positions of the non-zero
        entries in a row become that pair's frame indices. This is the same table layout
        that ``ComputedContacts.create_dataframe`` builds.

        Parameters
        ----------
        file : str or pathlib.Path
            The path to the file to load the contacts from.
        **kwargs
            Additional arguments passed to :func:`pandas.read_csv`. ``index_col`` defaults
            to ``[0, 1]``.

        Returns
        -------
        ComputedContacts
            The computed contacts.
        """
        read_kwargs = {"index_col": [0, 1], **kwargs}
        df = pd.read_csv(file, **read_kwargs)

        # Walk the rows directly rather than re-looking each one up with df.loc.
        contact_frames = defaultdict(lambda: defaultdict(list))
        for (residue_id, lipid_id), row in zip(df.index, df.to_numpy()):
            contact_frames[residue_id][lipid_id] = np.flatnonzero(row).tolist()

        # Count and store contacts
        contact_strategy_instance = self._contact_strategy(
            self.query.universe, contact_frames, self.params.get("norm_factor")
        )
        contact_strategy_instance.run()
        return ComputedContacts(contact_strategy_instance, self)