Pocket-Gen / utils /protein_ligand.py
Zaixi's picture
1
dcacefd
import sys
sys.path.append("..")
import os
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdchem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable',
'ZnBinder']
ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)}
BOND_TYPES = {t: i for i, t in enumerate(rdchem.BondType.names.values())}
BOND_NAMES = {i: t for i, t in enumerate(rdchem.BondType.names.keys())}
RESTYPE_1to3 = {
"A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS",
"I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP",
"Y": "TYR", "V": "VAL",
}
ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
ATOM_TYPES = [
'', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
'CZ3', 'NZ', 'OXT'
]
RES_ATOM14 = [
[''] * 14,
['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
]
NUM_ATOMS = [4, 5, 11, 8, 8, 6, 9, 9, 4, 10, 8, 8, 9, 8, 11, 7, 6, 7, 14, 12, 7]
class PDBProtein(object):
AA_NAME_SYM = {'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G',
'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'}
AA_NAME_NUMBER = {
k: i + 1 for i, (k, _) in enumerate(AA_NAME_SYM.items())
}
BACKBONE_NAMES = ["CA", "C", "N", "O"]
def __init__(self, data, mode='auto'):
super().__init__()
if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path':
with open(data, 'r') as f:
self.block = f.read()
else:
self.block = data
self.ptable = Chem.GetPeriodicTable()
# Molecule properties
self.title = None
# Atom properties
self.atoms = []
self.element = []
self.atomic_weight = []
self.pos = []
self.atom_name = []
self.is_backbone = []
self.atom_to_aa_type = []
self.atom2residue = []
# Residue properties
self.residues = []
self.amino_acid = []
self.amino_idx = []
self.center_of_mass = []
self.pos_CA = []
self.pos_C = []
self.pos_N = []
self.pos_O = []
self.residue_natoms = []
self.seq = []
self._parse()
def _enum_formatted_atom_lines(self):
for line in self.block.splitlines():
if line[0:6].strip() == 'ATOM':
element_symb = line[76:78].strip().capitalize()
if len(element_symb) == 0:
element_symb = line[13:14]
yield {
'line': line,
'type': 'ATOM',
'atom_id': int(line[6:11]),
'atom_name': line[12:16].strip(),
'res_name': line[17:20].strip(),
'chain': line[21:22].strip(),
'res_id': int(line[22:26]),
'res_insert_id': line[26:27].strip(),
'x': float(line[30:38]),
'y': float(line[38:46]),
'z': float(line[46:54]),
'occupancy': float(line[54:60]),
'segment': line[72:76].strip(),
'element_symb': element_symb,
'charge': line[78:80].strip(),
}
elif line[0:6].strip() == 'HEADER':
yield {
'type': 'HEADER',
'value': line[10:].strip()
}
elif line[0:6].strip() == 'ENDMDL':
break # Some PDBs have more than 1 model.
else:
yield {
'type': 'others'
}
def _parse(self):
# Process atoms
residues_tmp = {}
num_residue = -1
for atom in self._enum_formatted_atom_lines():
if atom['type'] == 'HEADER':
self.title = atom['value'].lower()
continue
if atom['type'] == 'others':
continue
if atom['atom_name'][0] == 'H' or atom['atom_name'] == 'OXT':
continue
self.atoms.append(atom)
atomic_number = self.ptable.GetAtomicNumber(atom['element_symb'])
next_ptr = len(self.element)
self.element.append(atomic_number)
self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number))
self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32))
self.atom_name.append(atom['atom_name'])
self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES)
self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']])
chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id'])
if chain_res_id not in residues_tmp:
num_residue += 1
residues_tmp[chain_res_id] = {
'name': atom['res_name'],
'atoms': [next_ptr],
'chain': atom['chain'],
'segment': atom['segment'],
'res_id': atom['res_id'],
'full_seq_idx': num_residue,
}
else:
assert residues_tmp[chain_res_id]['name'] == atom['res_name']
assert residues_tmp[chain_res_id]['chain'] == atom['chain']
residues_tmp[chain_res_id]['atoms'].append(next_ptr)
self.atom2residue.append(num_residue)
# Process residues
self.residues = [r for _, r in residues_tmp.items()]
for residue in self.residues:
sum_pos = np.zeros([3], dtype=np.float32)
sum_mass = 0.0
for atom_idx in residue['atoms']:
sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx]
sum_mass += self.atomic_weight[atom_idx]
if self.atom_name[atom_idx] in self.BACKBONE_NAMES:
residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx]
residue['center_of_mass'] = sum_pos / sum_mass
self.residue_natoms.append(len(residue['atoms']))
assert len(residue['atoms']) <= NUM_ATOMS[self.AA_NAME_NUMBER[residue['name']]]
# Process backbone atoms of residues
self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']])
self.center_of_mass.append(residue['center_of_mass'])
self.amino_idx.append(residue['res_id'])
self.seq.append(self.AA_NAME_SYM[residue['name']])
for name in self.BACKBONE_NAMES:
pos_key = 'pos_%s' % name # pos_CA, pos_C, pos_N, pos_O
if pos_key in residue:
getattr(self, pos_key).append(residue[pos_key])
else:
getattr(self, pos_key).append(residue['center_of_mass'])
# convert atom_name to number
self.atom_name = np.array([ATOM_TYPES.index(atom) for atom in self.atom_name])
self.pos = np.array(self.pos, dtype=np.float32)
def to_dict_atom(self):
return {
'element': np.array(self.element, dtype=np.longlong),
'molecule_name': self.title,
'pos': self.pos,
'is_backbone': np.array(self.is_backbone, dtype=bool),
'atom_name': self.atom_name,
'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.longlong),
'atom2residue': np.array(self.atom2residue, dtype=np.longlong)
}
def to_dict_residue(self):
return {
'seq': self.seq,
'res_idx': np.array(self.amino_idx, dtype=np.longlong),
'amino_acid': np.array(self.amino_acid, dtype=np.longlong),
'center_of_mass': np.array(self.center_of_mass, dtype=np.float32),
'pos_CA': np.array(self.pos_CA, dtype=np.float32),
'pos_C': np.array(self.pos_C, dtype=np.float32),
'pos_N': np.array(self.pos_N, dtype=np.float32),
'pos_O': np.array(self.pos_O, dtype=np.float32),
'residue_natoms': np.array(self.residue_natoms, dtype=np.longlong),
}
def query_residues_radius(self, center, radius, criterion='center_of_mass'):
center = np.array(center).reshape(3)
selected = []
for residue in self.residues:
distance = np.linalg.norm(residue[criterion] - center, ord=2)
print(residue[criterion], distance)
if distance < radius:
selected.append(residue)
return selected
def query_residues_ligand(self, ligand, radius=3.5, selected_residue=None, return_mask=True):
selected = []
sel_idx = set()
selected_mask = np.zeros(len(self.residues), dtype=bool)
full_seq_idx = set()
if selected_residue is None:
selected_residue = self.residues
# The time-complexity is O(mn).
for i, residue in enumerate(selected_residue):
for center in ligand['pos']:
distance = np.min(np.linalg.norm(self.pos[residue['atoms']] - center, ord=2, axis=1))
if distance <= radius and i not in sel_idx:
selected.append(residue)
sel_idx.add(i)
full_seq_idx.add(residue['full_seq_idx'])
break
selected_mask[list(sel_idx)] = 1
if return_mask:
return list(full_seq_idx), selected_mask
return list(full_seq_idx), selected
# can be used for select pocket residues
def residues_to_pdb_block(self, residues, name='POCKET'):
block = "HEADER %s\n" % name
block += "COMPND %s\n" % name
for residue in residues:
for atom_idx in residue['atoms']:
block += self.atoms[atom_idx]['line'] + "\n"
block += "END\n"
return block
def return_residues(self):
return self.residues, self.atoms
def parse_pdbbind_index_file(path):
pdb_id = []
with open(path, 'r') as f:
lines = f.readlines()
for line in lines:
if line.startswith('#'): continue
pdb_id.append(line.split()[0])
return pdb_id
def parse_sdf_file(path, feat=True):
mol = Chem.MolFromMolFile(path, sanitize=False)
fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
factory = ChemicalFeatures.BuildFeatureFactory(fdefName)
rd_num_atoms = mol.GetNumAtoms()
feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.longlong)
if feat:
rdmol = next(iter(Chem.SDMolSupplier(path, removeHs=True)))
for feat in factory.GetFeaturesForMol(rdmol):
feat_mat[feat.GetAtomIds(), ATOM_FAMILIES_ID[feat.GetFamily()]] = 1
with open(path, 'r') as f:
sdf = f.read()
sdf = sdf.splitlines()
num_atoms, num_bonds = map(int, [sdf[3][0:3], sdf[3][3:6]])
assert num_atoms == rd_num_atoms
ptable = Chem.GetPeriodicTable()
element, pos = [], []
accum_pos = np.array([0.0, 0.0, 0.0], dtype=np.float32)
accum_mass = 0.0
for atom_line in map(lambda x: x.split(), sdf[4:4 + num_atoms]):
x, y, z = map(float, atom_line[:3])
symb = atom_line[3]
atomic_number = ptable.GetAtomicNumber(symb.capitalize())
element.append(atomic_number)
pos.append([x, y, z])
atomic_weight = ptable.GetAtomicWeight(atomic_number)
accum_pos += np.array([x, y, z]) * atomic_weight
accum_mass += atomic_weight
center_of_mass = np.array(accum_pos / accum_mass, dtype=np.float32)
element = np.array(element, dtype=np.int32)
pos = np.array(pos, dtype=np.float32)
BOND_TYPES = {t: i for i, t in enumerate(rdchem.BondType.names.values())}
bond_type_map = {
1: BOND_TYPES[rdchem.BondType.SINGLE],
2: BOND_TYPES[rdchem.BondType.DOUBLE],
3: BOND_TYPES[rdchem.BondType.TRIPLE],
4: BOND_TYPES[rdchem.BondType.AROMATIC],
}
row, col, edge_type = [], [], []
for bond_line in sdf[4 + num_atoms:4 + num_atoms + num_bonds]:
start, end = int(bond_line[0:3]) - 1, int(bond_line[3:6]) - 1
row += [start, end]
col += [end, start]
edge_type += 2 * [bond_type_map[int(bond_line[6:9])]]
edge_index = np.array([row, col], dtype=np.longlong)
edge_type = np.array(edge_type, dtype=np.longlong)
perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
edge_index = edge_index[:, perm]
edge_type = edge_type[perm]
neighbor_dict = {}
# used in rotation angle prediction
for i, atom in enumerate(mol.GetAtoms()):
neighbor_dict[i] = [n.GetIdx() for n in atom.GetNeighbors()]
data = {
'element': element,
'pos': pos,
'bond_index': edge_index,
'bond_type': edge_type,
'center_of_mass': center_of_mass,
'atom_feature': feat_mat,
'neighbors': neighbor_dict
}
return data