Source code for prolint.analysis.shared_contacts
"""Shared contacts analysis for pairwise residue correlations."""
from typing import Optional, List, Dict, Set
from prolint.analysis.base import BaseAnalysis, AnalysisResult
[docs]
class SharedContactsAnalysis(BaseAnalysis):
"""Analyze pairwise correlations between query residues.
Identifies query residue pairs that interact with the same database
molecule simultaneously, revealing potential cooperative binding sites.
See Also
--------
TimeSeriesAnalysis : Contact dynamics over time
"""
"""Analysis name for registry."""
[docs]
description = (
"Pairwise correlations between query residues via shared database contacts"
)
"""Human-readable description."""
[docs]
def run(
self,
database_type: Optional[str] = None,
normalize: bool = False,
) -> AnalysisResult:
"""Compute shared contacts correlation matrix.
Parameters
----------
database_type : str, optional
Filter by database residue name (e.g., "CHOL").
If None, includes all database molecules.
normalize : bool, default=False
Whether to normalize the matrix to 0-1 range.
Returns
-------
AnalysisResult
Result with data containing:
- labels : list of int query residue IDs
- matrix : 2D list of int/float shared contact counts
- residue_to_index : dict mapping residue ID to matrix index
"""
# Filter contacts by database type and build inverted index
filtered_contacts = self._filter_by_database_type(database_type)
database_to_query_frames = self._build_inverted_index(filtered_contacts)
# Get all query residues from the inverted index
all_query_residues: Set[int] = set()
for query_frames_dict in database_to_query_frames.values():
all_query_residues.update(query_frames_dict.keys())
query_residue_list = sorted(int(r) for r in all_query_residues)
n = len(query_residue_list)
if n < 2:
return AnalysisResult(
data={
"labels": query_residue_list,
"matrix": [],
"residue_to_index": {},
},
metadata={
"database_type": database_type,
"n_residues": n,
"normalize": normalize,
},
)
residue_to_index = {res: idx for idx, res in enumerate(query_residue_list)}
matrix = [[0] * n for _ in range(n)]
# Compute pairwise shared contacts
for query_frames_dict in database_to_query_frames.values():
query_resids = list(query_frames_dict.keys())
for i in range(len(query_resids)):
for j in range(i + 1, len(query_resids)):
res1, res2 = query_resids[i], query_resids[j]
shared_frames = len(
query_frames_dict[res1] & query_frames_dict[res2]
)
if shared_frames > 0:
idx1, idx2 = (
residue_to_index[int(res1)],
residue_to_index[int(res2)],
)
matrix[idx1][idx2] += shared_frames
matrix[idx2][idx1] += shared_frames
if normalize and matrix:
max_val = max(max(row) for row in matrix)
if max_val > 0:
matrix = [[v / max_val for v in row] for row in matrix]
return AnalysisResult(
data={
"labels": query_residue_list,
"matrix": matrix,
"residue_to_index": residue_to_index,
},
metadata={
"database_type": database_type,
"n_residues": n,
"normalize": normalize,
"n_database_molecules": len(database_to_query_frames),
},
)
def _build_inverted_index(
self, filtered_contacts: Dict[int, Dict[int, List[int]]]
) -> Dict[int, Dict[int, Set[int]]]:
"""Build inverted index from contact frames.
Parameters
----------
filtered_contacts : dict
Contact frames dict mapping query_resid -> database_id -> frame list.
Returns
-------
dict
Inverted mapping: database_id -> {query_resid: set of frames}.
"""
database_to_query_frames: Dict[int, Dict[int, Set[int]]] = {}
for query_resid, db_dict in filtered_contacts.items():
for db_id, frames in db_dict.items():
if db_id not in database_to_query_frames:
database_to_query_frames[db_id] = {}
database_to_query_frames[db_id][query_resid] = set(frames)
return database_to_query_frames