"""Atom group extensions for ProLint.
This module provides extended atom group classes with additional
functionality for biomolecular interaction analysis.
"""
from abc import ABC, abstractmethod
from typing import Iterable, Union, Dict, Optional, Any
from collections import Counter
from functools import cached_property
import numpy as np
import MDAnalysis as mda
class PLAtomGroupBase(ABC):
    """Interface that every ProLint atom-group implementation must provide.

    The contract covers three areas:

    * editing a selection (:meth:`add`, :meth:`remove`), which returns a new
      group rather than mutating this one;
    * residue name <-> residue ID lookups (:meth:`get_resnames`,
      :meth:`get_resids`, :meth:`get_all_resids`,
      :meth:`filter_resids_by_resname`);
    * per-residue-name summaries (:attr:`unique_resnames`,
      :attr:`resname_counts`).
    """

    # ------------------------------------------------------------------
    # Summary properties
    # ------------------------------------------------------------------
    @property
    @abstractmethod
    def unique_resnames(self) -> np.ndarray:
        """Distinct residue names present in the group.

        Returns
        -------
        ndarray
            One entry per distinct residue name.
        """

    @property
    @abstractmethod
    def resname_counts(self) -> Counter:
        """Number of residues carrying each residue name.

        Returns
        -------
        Counter
            ``{resname: number_of_residues}``.
        """

    # ------------------------------------------------------------------
    # Selection editing
    # ------------------------------------------------------------------
    @abstractmethod
    def add(
        self,
        resname: Optional[Union[str, list[str]]] = None,
        atomname: Optional[Union[str, list[str]]] = None,
        resnum: Optional[Union[int, list[int]]] = None,
        atomids: Optional[Union[int, list[int]]] = None,
    ) -> "ExtendedAtomGroup":
        """Return a new group extended with the atoms matching the criteria.

        Parameters
        ----------
        resname : str or list of str, optional
            Residue name(s) whose atoms should be included.
        atomname : str or list of str, optional
            Atom name(s) to include.
        resnum : int or list of int, optional
            Residue number(s) whose atoms should be included.
        atomids : int or list of int, optional
            Atom ID(s) to include.

        Returns
        -------
        ExtendedAtomGroup
            A new group containing the original atoms plus the matches.
        """

    @abstractmethod
    def remove(
        self,
        resname: Optional[Union[str, list[str]]] = None,
        atomname: Optional[Union[str, list[str]]] = None,
        resnum: Optional[Union[int, list[int]]] = None,
        atomids: Optional[Union[int, list[int]]] = None,
    ) -> "ExtendedAtomGroup":
        """Return a new group with the atoms matching the criteria dropped.

        Parameters
        ----------
        resname : str or list of str, optional
            Residue name(s) whose atoms should be excluded.
        atomname : str or list of str, optional
            Atom name(s) to exclude.
        resnum : int or list of int, optional
            Residue number(s) whose atoms should be excluded.
        atomids : int or list of int, optional
            Atom ID(s) to exclude.

        Returns
        -------
        ExtendedAtomGroup
            A new group without the matching atoms.
        """

    # ------------------------------------------------------------------
    # Residue lookups
    # ------------------------------------------------------------------
    @abstractmethod
    def get_resnames(
        self, resids: Iterable[int], out: Union[type[list], type[dict]] = list
    ) -> Union[list[str], Dict[int, str]]:
        """Translate residue IDs into residue names.

        Parameters
        ----------
        resids : Iterable of int
            Residue IDs to translate.
        out : type, default=list
            ``list`` for names in input order, ``dict`` for a
            ``{resid: resname}`` mapping.

        Returns
        -------
        list of str or dict
            The residue names in the requested container.
        """

    @abstractmethod
    def get_resids(self, resname: str) -> np.ndarray:
        """Collect the residue IDs of every residue named ``resname``.

        Parameters
        ----------
        resname : str
            The residue name to search for.

        Returns
        -------
        ndarray
            Residue IDs whose residue name equals ``resname``.
        """

    @abstractmethod
    def get_all_resids(
        self, resnames: Iterable[str], out: Union[type[list], type[dict]] = list
    ) -> Union[list[np.ndarray], Dict[str, np.ndarray]]:
        """Collect residue IDs for several residue names at once.

        Parameters
        ----------
        resnames : Iterable of str
            Residue names to search for.
        out : type, default=list
            ``list`` for one array per name in input order, ``dict`` for a
            ``{resname: resids}`` mapping.

        Returns
        -------
        list or dict
            The residue-ID arrays in the requested container.
        """

    @abstractmethod
    def filter_resids_by_resname(
        self, resids: Iterable[int], resname: str
    ) -> np.ndarray:
        """Keep only the residue IDs whose residue is named ``resname``.

        Parameters
        ----------
        resids : Iterable of int
            Candidate residue IDs.
        resname : str
            Residue name that survivors must carry.

        Returns
        -------
        ndarray
            The subset of ``resids`` matching ``resname``.
        """
class ExtendedAtomGroup(mda.AtomGroup, PLAtomGroupBase):
    """Extended MDAnalysis AtomGroup with additional ProLint functionality.

    Provides enhanced methods for manipulating and querying atom selections,
    with cached properties for efficient repeated access.

    Parameters
    ----------
    *args : tuple
        Arguments passed to MDAnalysis AtomGroup.
    **kwargs : dict
        Keyword arguments passed to MDAnalysis AtomGroup.

    Examples
    --------
    >>> from prolint import Universe
    >>> u = Universe("topology.gro", "trajectory.xtc")
    >>> u.database.unique_resnames
    array(['CHOL', 'POPC', 'POPE'], dtype=object)
    >>> u.database.resname_counts
    Counter({'POPC': 128, 'POPE': 64, 'CHOL': 32})

    Add atoms:

    >>> extended = u.database.add(resname="DPPC")

    Remove atoms:

    >>> filtered = u.database.remove(resname="CHOL")

    See Also
    --------
    Universe.query : Query atom group (typically protein)
    Universe.database : Database atom group (typically lipids)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Eager per-residue snapshots, aligned element-by-element with
        # ``self.residues``. (``_resname_resid_labels`` below is the lazily
        # built counterpart.)
        self._stored_resnames = self.residues.resnames
        self._stored_resids = self.residues.resids

    @cached_property
    def _resname_resid_labels(self) -> Dict[int, str]:
        """``{resid: resname}`` for the residues of this group.

        Built on first access and cached on the instance. If several
        residues share a resid, the last one in ``self.residues`` wins.
        """
        resnames = self.residues.resnames
        resids = self.residues.resids
        return dict(zip(resids, resnames))

    @staticmethod
    def _criterion_values(value: Any, scalar_types: tuple) -> list[str]:
        """Normalize a scalar-or-iterable criterion to a list of strings."""
        if isinstance(value, scalar_types):
            return [str(value)]
        return [str(item) for item in value]

    def _build_selection_string(
        self,
        resname: Optional[Union[str, list[str]]] = None,
        atomname: Optional[Union[str, list[str]]] = None,
        resnum: Optional[Union[int, list[int]]] = None,
        atomids: Optional[Union[int, list[int]]] = None,
    ) -> str:
        """Build an MDAnalysis selection string from filter criteria.

        All criteria are OR-ed together, so an atom matching any single
        value of any criterion is selected.

        Parameters
        ----------
        resname : str or list of str, optional
            Residue name(s) (``resname`` keyword).
        atomname : str or list of str, optional
            Atom name(s) (``name`` keyword).
        resnum : int or list of int, optional
            Residue number(s) (``resid`` keyword). Python and NumPy integer
            scalars are accepted.
        atomids : int or list of int, optional
            Atom ID(s) (``bynum`` keyword, which MDAnalysis treats as
            1-based). Python and NumPy integer scalars are accepted.

        Returns
        -------
        str
            MDAnalysis selection string.

        Raises
        ------
        ValueError
            If no criterion is provided (all ``None`` or empty).
        """
        criteria = (
            ("resname", resname, (str,)),
            ("name", atomname, (str,)),
            ("resid", resnum, (int, np.integer)),
            ("bynum", atomids, (int, np.integer)),
        )
        selections = []
        for keyword, value, scalar_types in criteria:
            if value is None:
                continue
            values = self._criterion_values(value, scalar_types)
            # An empty list would yield an invalid "keyword " fragment.
            if values:
                selections.append(f"{keyword} " + f" or {keyword} ".join(values))
        if not selections:
            raise ValueError("At least one selection criterion must be provided")
        return " or ".join(selections)

    def _wrap(self, group: mda.AtomGroup) -> "ExtendedAtomGroup":
        """Re-wrap a plain AtomGroup as an instance of this class.

        Uses the ``(indices, universe)`` constructor form. The single-argument
        form iterates atoms in Python and reads ``group[0]``, which fails
        for an empty group.
        """
        return self.__class__(group.ix, group.universe)

    def add(
        self,
        resname: Optional[Union[str, list[str]]] = None,
        atomname: Optional[Union[str, list[str]]] = None,
        resnum: Optional[Union[int, list[int]]] = None,
        atomids: Optional[Union[int, list[int]]] = None,
    ) -> "ExtendedAtomGroup":
        """Add atoms to the atom group.

        Matching atoms are selected from the whole universe and unioned with
        this group; this group itself is not modified.
        See :meth:`PLAtomGroupBase.add` for parameter documentation.

        Raises
        ------
        ValueError
            If no selection criterion is provided.
        """
        selection_string = self._build_selection_string(
            resname, atomname, resnum, atomids
        )
        selected = self.universe.atoms.select_atoms(selection_string)
        return self._wrap(self | selected)

    def remove(
        self,
        resname: Optional[Union[str, list[str]]] = None,
        atomname: Optional[Union[str, list[str]]] = None,
        resnum: Optional[Union[int, list[int]]] = None,
        atomids: Optional[Union[int, list[int]]] = None,
    ) -> "ExtendedAtomGroup":
        """Remove atoms from the atom group.

        Matching atoms are looked up within this group and subtracted; the
        result may be empty. This group itself is not modified.
        See :meth:`PLAtomGroupBase.remove` for parameter documentation.

        Raises
        ------
        ValueError
            If no selection criterion is provided.
        """
        selection_string = self._build_selection_string(
            resname, atomname, resnum, atomids
        )
        atoms_to_remove = self.select_atoms(selection_string)
        return self._wrap(self - atoms_to_remove)

    def get_resnames(
        self, resids: Iterable[int], out: Union[type[list], type[dict]] = list
    ) -> Union[list[str], Dict[int, str]]:
        """Get residue names for given residue IDs.

        See :meth:`PLAtomGroupBase.get_resnames` for parameter documentation.

        Raises
        ------
        KeyError
            If a resid does not belong to any residue of this group.
        ValueError
            If ``out`` is neither ``list`` nor ``dict``.
        """
        labels = self._resname_resid_labels
        if out is list:
            return [labels[resid] for resid in resids]
        elif out is dict:
            return {resid: labels[resid] for resid in resids}
        else:
            raise ValueError("out must be either list or dict")

    def get_resids(self, resname: str) -> np.ndarray:
        """Get residue IDs for a given residue name.

        See :meth:`PLAtomGroupBase.get_resids` for parameter documentation.
        """
        return self.residues.resids[self.residues.resnames == resname]

    def get_all_resids(
        self, resnames: Iterable[str], out: Union[type[list], type[dict]] = list
    ) -> Union[list[np.ndarray], Dict[str, np.ndarray]]:
        """Get residue IDs for multiple residue names.

        See :meth:`PLAtomGroupBase.get_all_resids` for parameter documentation.

        Raises
        ------
        ValueError
            If ``out`` is neither ``list`` nor ``dict``.
        """
        if out is list:
            return [self.get_resids(resname) for resname in resnames]
        elif out is dict:
            return {resname: self.get_resids(resname) for resname in resnames}
        else:
            raise ValueError("out must be either list or dict")

    def filter_resids_by_resname(
        self, resids: Iterable[int], resname: str
    ) -> np.ndarray:
        """Filter residue IDs to keep only those matching a residue name.

        The input order (and any duplicates) of ``resids`` is preserved.
        Resids that do not belong to this group are dropped.
        See :meth:`PLAtomGroupBase.filter_resids_by_resname` for parameter
        documentation.
        """
        resids = np.asarray(resids)
        # Resids of every residue in this group carrying ``resname``.
        matching = self._stored_resids[self._stored_resnames == resname]
        # A membership test needs neither sorted nor unique group resids.
        # It also never mis-maps a resid that is absent from the group
        # (``searchsorted`` did both).
        return resids[np.isin(resids, matching)]

    @property
    def unique_resnames(self) -> np.ndarray:
        """Unique residue names in the atom group.

        Returns
        -------
        ndarray
            Sorted array of unique residue names (``np.unique`` sorts).
        """
        return np.unique(self.residues.resnames)  # type: ignore[return-value]

    @property
    def resname_counts(self) -> Counter:
        """Count of residues (not atoms) for each residue name.

        Returns
        -------
        Counter
            Mapping of residue name to number of residues.
        """
        return Counter(self.residues.resnames)

    def __str__(self) -> str:
        return f"<ProLint Wrapper for {super().__str__()}>"

    def __repr__(self) -> str:
        return f"<ProLint Wrapper for {super().__repr__()}>"