Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.append("..") | |
| import os | |
| import numpy as np | |
| from rdkit import Chem | |
| from rdkit.Chem import rdchem | |
| from rdkit.Chem import ChemicalFeatures | |
| from rdkit import RDConfig | |
| ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable', | |
| 'ZnBinder'] | |
| ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)} | |
| BOND_TYPES = {t: i for i, t in enumerate(rdchem.BondType.names.values())} | |
| BOND_NAMES = {i: t for i, t in enumerate(rdchem.BondType.names.keys())} | |
| RESTYPE_1to3 = { | |
| "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", | |
| "I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP", | |
| "Y": "TYR", "V": "VAL", | |
| } | |
| ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] | |
| ATOM_TYPES = [ | |
| '', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', | |
| 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', | |
| 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', | |
| 'CZ3', 'NZ', 'OXT' | |
| ] | |
| RES_ATOM14 = [ | |
| [''] * 14, | |
| ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], | |
| ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], | |
| ] | |
| NUM_ATOMS = [4, 5, 11, 8, 8, 6, 9, 9, 4, 10, 8, 8, 9, 8, 11, 7, 6, 7, 14, 12, 7] | |
| class PDBProtein(object): | |
| AA_NAME_SYM = {'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', | |
| 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S', | |
| 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'} | |
| AA_NAME_NUMBER = { | |
| k: i + 1 for i, (k, _) in enumerate(AA_NAME_SYM.items()) | |
| } | |
| BACKBONE_NAMES = ["CA", "C", "N", "O"] | |
| def __init__(self, data, mode='auto'): | |
| super().__init__() | |
| if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path': | |
| with open(data, 'r') as f: | |
| self.block = f.read() | |
| else: | |
| self.block = data | |
| self.ptable = Chem.GetPeriodicTable() | |
| # Molecule properties | |
| self.title = None | |
| # Atom properties | |
| self.atoms = [] | |
| self.element = [] | |
| self.atomic_weight = [] | |
| self.pos = [] | |
| self.atom_name = [] | |
| self.is_backbone = [] | |
| self.atom_to_aa_type = [] | |
| self.atom2residue = [] | |
| # Residue properties | |
| self.residues = [] | |
| self.amino_acid = [] | |
| self.amino_idx = [] | |
| self.center_of_mass = [] | |
| self.pos_CA = [] | |
| self.pos_C = [] | |
| self.pos_N = [] | |
| self.pos_O = [] | |
| self.residue_natoms = [] | |
| self.seq = [] | |
| self._parse() | |
| def _enum_formatted_atom_lines(self): | |
| for line in self.block.splitlines(): | |
| if line[0:6].strip() == 'ATOM': | |
| element_symb = line[76:78].strip().capitalize() | |
| if len(element_symb) == 0: | |
| element_symb = line[13:14] | |
| yield { | |
| 'line': line, | |
| 'type': 'ATOM', | |
| 'atom_id': int(line[6:11]), | |
| 'atom_name': line[12:16].strip(), | |
| 'res_name': line[17:20].strip(), | |
| 'chain': line[21:22].strip(), | |
| 'res_id': int(line[22:26]), | |
| 'res_insert_id': line[26:27].strip(), | |
| 'x': float(line[30:38]), | |
| 'y': float(line[38:46]), | |
| 'z': float(line[46:54]), | |
| 'occupancy': float(line[54:60]), | |
| 'segment': line[72:76].strip(), | |
| 'element_symb': element_symb, | |
| 'charge': line[78:80].strip(), | |
| } | |
| elif line[0:6].strip() == 'HEADER': | |
| yield { | |
| 'type': 'HEADER', | |
| 'value': line[10:].strip() | |
| } | |
| elif line[0:6].strip() == 'ENDMDL': | |
| break # Some PDBs have more than 1 model. | |
| else: | |
| yield { | |
| 'type': 'others' | |
| } | |
| def _parse(self): | |
| # Process atoms | |
| residues_tmp = {} | |
| num_residue = -1 | |
| for atom in self._enum_formatted_atom_lines(): | |
| if atom['type'] == 'HEADER': | |
| self.title = atom['value'].lower() | |
| continue | |
| if atom['type'] == 'others': | |
| continue | |
| if atom['atom_name'][0] == 'H' or atom['atom_name'] == 'OXT': | |
| continue | |
| self.atoms.append(atom) | |
| atomic_number = self.ptable.GetAtomicNumber(atom['element_symb']) | |
| next_ptr = len(self.element) | |
| self.element.append(atomic_number) | |
| self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number)) | |
| self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32)) | |
| self.atom_name.append(atom['atom_name']) | |
| self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES) | |
| self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']]) | |
| chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id']) | |
| if chain_res_id not in residues_tmp: | |
| num_residue += 1 | |
| residues_tmp[chain_res_id] = { | |
| 'name': atom['res_name'], | |
| 'atoms': [next_ptr], | |
| 'chain': atom['chain'], | |
| 'segment': atom['segment'], | |
| 'res_id': atom['res_id'], | |
| 'full_seq_idx': num_residue, | |
| } | |
| else: | |
| assert residues_tmp[chain_res_id]['name'] == atom['res_name'] | |
| assert residues_tmp[chain_res_id]['chain'] == atom['chain'] | |
| residues_tmp[chain_res_id]['atoms'].append(next_ptr) | |
| self.atom2residue.append(num_residue) | |
| # Process residues | |
| self.residues = [r for _, r in residues_tmp.items()] | |
| for residue in self.residues: | |
| sum_pos = np.zeros([3], dtype=np.float32) | |
| sum_mass = 0.0 | |
| for atom_idx in residue['atoms']: | |
| sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx] | |
| sum_mass += self.atomic_weight[atom_idx] | |
| if self.atom_name[atom_idx] in self.BACKBONE_NAMES: | |
| residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx] | |
| residue['center_of_mass'] = sum_pos / sum_mass | |
| self.residue_natoms.append(len(residue['atoms'])) | |
| assert len(residue['atoms']) <= NUM_ATOMS[self.AA_NAME_NUMBER[residue['name']]] | |
| # Process backbone atoms of residues | |
| self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']]) | |
| self.center_of_mass.append(residue['center_of_mass']) | |
| self.amino_idx.append(residue['res_id']) | |
| self.seq.append(self.AA_NAME_SYM[residue['name']]) | |
| for name in self.BACKBONE_NAMES: | |
| pos_key = 'pos_%s' % name # pos_CA, pos_C, pos_N, pos_O | |
| if pos_key in residue: | |
| getattr(self, pos_key).append(residue[pos_key]) | |
| else: | |
| getattr(self, pos_key).append(residue['center_of_mass']) | |
| # convert atom_name to number | |
| self.atom_name = np.array([ATOM_TYPES.index(atom) for atom in self.atom_name]) | |
| self.pos = np.array(self.pos, dtype=np.float32) | |
| def to_dict_atom(self): | |
| return { | |
| 'element': np.array(self.element, dtype=np.longlong), | |
| 'molecule_name': self.title, | |
| 'pos': self.pos, | |
| 'is_backbone': np.array(self.is_backbone, dtype=bool), | |
| 'atom_name': self.atom_name, | |
| 'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.longlong), | |
| 'atom2residue': np.array(self.atom2residue, dtype=np.longlong) | |
| } | |
| def to_dict_residue(self): | |
| return { | |
| 'seq': self.seq, | |
| 'res_idx': np.array(self.amino_idx, dtype=np.longlong), | |
| 'amino_acid': np.array(self.amino_acid, dtype=np.longlong), | |
| 'center_of_mass': np.array(self.center_of_mass, dtype=np.float32), | |
| 'pos_CA': np.array(self.pos_CA, dtype=np.float32), | |
| 'pos_C': np.array(self.pos_C, dtype=np.float32), | |
| 'pos_N': np.array(self.pos_N, dtype=np.float32), | |
| 'pos_O': np.array(self.pos_O, dtype=np.float32), | |
| 'residue_natoms': np.array(self.residue_natoms, dtype=np.longlong), | |
| } | |
| def query_residues_radius(self, center, radius, criterion='center_of_mass'): | |
| center = np.array(center).reshape(3) | |
| selected = [] | |
| for residue in self.residues: | |
| distance = np.linalg.norm(residue[criterion] - center, ord=2) | |
| print(residue[criterion], distance) | |
| if distance < radius: | |
| selected.append(residue) | |
| return selected | |
| def query_residues_ligand(self, ligand, radius=3.5, selected_residue=None, return_mask=True): | |
| selected = [] | |
| sel_idx = set() | |
| selected_mask = np.zeros(len(self.residues), dtype=bool) | |
| full_seq_idx = set() | |
| if selected_residue is None: | |
| selected_residue = self.residues | |
| # The time-complexity is O(mn). | |
| for i, residue in enumerate(selected_residue): | |
| for center in ligand['pos']: | |
| distance = np.min(np.linalg.norm(self.pos[residue['atoms']] - center, ord=2, axis=1)) | |
| if distance <= radius and i not in sel_idx: | |
| selected.append(residue) | |
| sel_idx.add(i) | |
| full_seq_idx.add(residue['full_seq_idx']) | |
| break | |
| selected_mask[list(sel_idx)] = 1 | |
| if return_mask: | |
| return list(full_seq_idx), selected_mask | |
| return list(full_seq_idx), selected | |
| # can be used for select pocket residues | |
| def residues_to_pdb_block(self, residues, name='POCKET'): | |
| block = "HEADER %s\n" % name | |
| block += "COMPND %s\n" % name | |
| for residue in residues: | |
| for atom_idx in residue['atoms']: | |
| block += self.atoms[atom_idx]['line'] + "\n" | |
| block += "END\n" | |
| return block | |
| def return_residues(self): | |
| return self.residues, self.atoms | |
| def parse_pdbbind_index_file(path): | |
| pdb_id = [] | |
| with open(path, 'r') as f: | |
| lines = f.readlines() | |
| for line in lines: | |
| if line.startswith('#'): continue | |
| pdb_id.append(line.split()[0]) | |
| return pdb_id | |
| def parse_sdf_file(path, feat=True): | |
| mol = Chem.MolFromMolFile(path, sanitize=False) | |
| fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef') | |
| factory = ChemicalFeatures.BuildFeatureFactory(fdefName) | |
| rd_num_atoms = mol.GetNumAtoms() | |
| feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.longlong) | |
| if feat: | |
| rdmol = next(iter(Chem.SDMolSupplier(path, removeHs=True))) | |
| for feat in factory.GetFeaturesForMol(rdmol): | |
| feat_mat[feat.GetAtomIds(), ATOM_FAMILIES_ID[feat.GetFamily()]] = 1 | |
| with open(path, 'r') as f: | |
| sdf = f.read() | |
| sdf = sdf.splitlines() | |
| num_atoms, num_bonds = map(int, [sdf[3][0:3], sdf[3][3:6]]) | |
| assert num_atoms == rd_num_atoms | |
| ptable = Chem.GetPeriodicTable() | |
| element, pos = [], [] | |
| accum_pos = np.array([0.0, 0.0, 0.0], dtype=np.float32) | |
| accum_mass = 0.0 | |
| for atom_line in map(lambda x: x.split(), sdf[4:4 + num_atoms]): | |
| x, y, z = map(float, atom_line[:3]) | |
| symb = atom_line[3] | |
| atomic_number = ptable.GetAtomicNumber(symb.capitalize()) | |
| element.append(atomic_number) | |
| pos.append([x, y, z]) | |
| atomic_weight = ptable.GetAtomicWeight(atomic_number) | |
| accum_pos += np.array([x, y, z]) * atomic_weight | |
| accum_mass += atomic_weight | |
| center_of_mass = np.array(accum_pos / accum_mass, dtype=np.float32) | |
| element = np.array(element, dtype=np.int32) | |
| pos = np.array(pos, dtype=np.float32) | |
| BOND_TYPES = {t: i for i, t in enumerate(rdchem.BondType.names.values())} | |
| bond_type_map = { | |
| 1: BOND_TYPES[rdchem.BondType.SINGLE], | |
| 2: BOND_TYPES[rdchem.BondType.DOUBLE], | |
| 3: BOND_TYPES[rdchem.BondType.TRIPLE], | |
| 4: BOND_TYPES[rdchem.BondType.AROMATIC], | |
| } | |
| row, col, edge_type = [], [], [] | |
| for bond_line in sdf[4 + num_atoms:4 + num_atoms + num_bonds]: | |
| start, end = int(bond_line[0:3]) - 1, int(bond_line[3:6]) - 1 | |
| row += [start, end] | |
| col += [end, start] | |
| edge_type += 2 * [bond_type_map[int(bond_line[6:9])]] | |
| edge_index = np.array([row, col], dtype=np.longlong) | |
| edge_type = np.array(edge_type, dtype=np.longlong) | |
| perm = (edge_index[0] * num_atoms + edge_index[1]).argsort() | |
| edge_index = edge_index[:, perm] | |
| edge_type = edge_type[perm] | |
| neighbor_dict = {} | |
| # used in rotation angle prediction | |
| for i, atom in enumerate(mol.GetAtoms()): | |
| neighbor_dict[i] = [n.GetIdx() for n in atom.GetNeighbors()] | |
| data = { | |
| 'element': element, | |
| 'pos': pos, | |
| 'bond_index': edge_index, | |
| 'bond_type': edge_type, | |
| 'center_of_mass': center_of_mass, | |
| 'atom_feature': feat_mat, | |
| 'neighbors': neighbor_dict | |
| } | |
| return data | |