"""Chemistry helpers: geometry alignment, fragment decomposition and assembly.

The module bundles three groups of utilities:

* small geometric helpers (random rotation about an axis, Kabsch alignment,
  coordinate recovery from a distance matrix);
* RDKit helpers for copying, sanitising, decomposing and re-attaching
  molecular fragments;
* conversion of an RDKit molecule into a ``torch_geometric`` ``Data`` object.
"""
from collections import defaultdict
from copy import deepcopy
from math import sqrt
from random import sample

import numpy as np
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

import rdkit
import rdkit.Chem as Chem
from rdkit.Chem import BRICS
from rdkit.Chem.Descriptors import MolLogP, qed
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
from torch_geometric.data import Data, Batch

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000


def vina_score(mol, use_uff=True):
    """Prepare ``mol`` as a ligand: add hydrogens and optionally UFF-optimise.

    NOTE(review): despite its name this function does not compute a docking
    score; it only performs ligand preparation.

    :param mol: RDKit molecule.
    :param use_uff: when true, run ``UFFOptimizeMolecule`` on the hydrogenated
        copy.  (The original body read an undefined global of this name and
        therefore raised ``NameError`` on every call.)
    :return: the hydrogenated (and optionally optimised) molecule.
    """
    ligand_rdmol = Chem.AddHs(mol, addCoords=True)
    if use_uff:
        UFFOptimizeMolecule(ligand_rdmol)
    return ligand_rdmol


def lipinski(mol):
    """Return True when ``mol`` passes the drug-likeness thresholds below.

    Thresholds: logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10,
    exact molecular weight <= 500 and rotatable bonds <= 5.

    The logP check previously tested ``qed(mol) <= 5``, which is always true
    because QED lies in [0, 1]; Crippen logP (``MolLogP``) is used instead.
    """
    return (
        MolLogP(mol) <= 5
        and Chem.Lipinski.NumHDonors(mol) <= 5
        and Chem.Lipinski.NumHAcceptors(mol) <= 10
        and Chem.Descriptors.ExactMolWt(mol) <= 500
        and Chem.Lipinski.NumRotatableBonds(mol) <= 5
    )


def list_filter(a, b):
    """Return the elements of ``a`` (order and duplicates kept) that are in ``b``."""
    return [item for item in a if item in b]


def rand_rotate(dir, ref, pos, alpha=None):
    """Rotate ``pos`` by ``alpha`` about the axis through ``ref`` along ``dir``.

    :param dir: 3-vector giving the axis direction.  It is NOT normalised
        here (see the disabled line below); the matrix is a proper rotation
        only if the caller passes a unit vector.
    :param ref: 3-vector, a point on the rotation axis.
    :param pos: tensor of shape [n_pos, 3] with the points to rotate.
    :param alpha: rotation angle in radians; drawn from ``torch.randn(1)``
        when omitted.
    :return: tensor of shape [n_pos, 3] with the rotated points.
    """
    # dir = dir/torch.norm(dir)
    if alpha is None:
        alpha = torch.randn(1)
    n_pos = pos.shape[0]
    sin, cos = torch.sin(alpha), torch.cos(alpha)
    K = 1 - cos
    M = torch.dot(dir, ref)
    nx, ny, nz = dir[0], dir[1], dir[2]
    x0, y0, z0 = ref[0], ref[1], ref[2]
    # 4x4 homogeneous transform: rotation block plus the translation that
    # keeps the axis point ``ref`` fixed.
    T = torch.tensor([
        nx ** 2 * K + cos, nx * ny * K - nz * sin, nx * nz * K + ny * sin,
        (x0 - nx * M) * K + (nz * y0 - ny * z0) * sin,
        nx * ny * K + nz * sin, ny ** 2 * K + cos, ny * nz * K - nx * sin,
        (y0 - ny * M) * K + (nx * z0 - nz * x0) * sin,
        nx * nz * K - ny * sin, ny * nz * K + nx * sin, nz ** 2 * K + cos,
        (z0 - nz * M) * K + (ny * x0 - nx * y0) * sin,
        0, 0, 0, 1]).reshape(4, 4)
    # Homogeneous coordinates: [4, n_pos].
    pos = torch.cat([pos.t(), torch.ones(n_pos).unsqueeze(0)], dim=0)
    rotated_pos = torch.mm(T, pos)[:3]
    return rotated_pos.t()


def kabsch(A, B):
    """Find the rigid transform that maps the points ``B`` onto ``A``.

    :param A: nominal points, N x 3.
    :param B: measured points, N x 3.
    :return: ``(R, t)`` where ``R`` is the 3x3 rotation (B to A) and ``t`` the
        translation (B to A), so that ``A ~= (R @ B.T).T + t``.  For ndarray
        input ``t`` has shape (3,); for ``np.matrix`` input it is 3x1.

    Matrix products use ``@``.  The original used ``*``, which is only a
    matrix product for ``np.matrix`` operands and is element-wise (wrong, or
    a broadcasting error) for plain ndarrays.
    """
    assert len(A) == len(B)
    N = A.shape[0]  # total points
    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    # center the points
    AA = A - np.tile(centroid_A, (N, 1))
    BB = B - np.tile(centroid_B, (N, 1))
    H = np.transpose(BB) @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    # special reflection case
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T
    t = -R @ centroid_B.T + centroid_A.T
    return R, t


def kabsch_torch(A, B, C):
    """Align ``A`` to ``B`` and apply the resulting transform to ``C``.

    All inputs are cast to double precision.  Unlike :func:`kabsch`, no
    reflection correction is applied to the rotation.

    :param A: source points, [N, 3].
    :param B: target points, [N, 3].
    :param C: points to transform, [M, 3].
    :return: ``(C_aligned, R, t)`` with ``R`` of shape [3, 3] and ``t`` of
        shape [1, 3].
    """
    A = A.double()
    B = B.double()
    C = C.double()
    a_mean = A.mean(dim=0, keepdims=True)
    b_mean = B.mean(dim=0, keepdims=True)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix
    H = torch.matmul(A_c.transpose(0, 1), B_c)  # [3, 3]
    U, S, V = torch.svd(H)
    # Rotation matrix
    R = torch.matmul(V, U.transpose(0, 1))  # [3, 3]
    # Translation vector
    t = b_mean - torch.matmul(R, a_mean.transpose(0, 1)).transpose(0, 1)
    C_aligned = torch.matmul(R, C.transpose(0, 1)).transpose(0, 1) + t
    return C_aligned, R, t


def eig_coord_from_dist(D):
    """Recover 3-D coordinates from the matrix ``D`` via eigendecomposition.

    ``D`` is presumably a squared-distance matrix such as the output of
    :func:`self_square_dist` -- TODO confirm against callers.  A Gram-like
    matrix relative to the first point is built, decomposed with
    ``torch.linalg.eigh`` (eigenvalues ascending), and the three columns
    belonging to the largest eigenvalues are returned, detached.
    """
    M = (D[:1, :] + D[:, :1] - D) / 2
    L, V = torch.linalg.eigh(M)
    L = torch.diag_embed(L)
    # Negative eigenvalues are clamped to zero before the square root.
    X = torch.matmul(V, L.clamp(min=0).sqrt())
    return X[:, -3:].detach()


def self_square_dist(X):
    """Return the [N, N] matrix of squared distances between the rows of ``X``."""
    dX = X.unsqueeze(0) - X.unsqueeze(1)  # [1, N, 3] - [N, 1, 3]
    D = torch.sum(dX ** 2, dim=-1)
    return D


def set_atommap(mol, num=0):
    """Set the atom-map number of every atom in ``mol`` to ``num`` (in place)."""
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(num)


def get_mol(smiles):
    """Parse ``smiles`` into a kekulized molecule; return None if parsing fails."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.Kekulize(mol)
    return mol


def get_smiles(mol):
    """Return the kekulized SMILES string of ``mol``."""
    return Chem.MolToSmiles(mol, kekuleSmiles=True)


def decode_stereo(smiles2D):
    """Enumerate isomeric SMILES for every stereoisomer of ``smiles2D``.

    When the first isomer has chiral nitrogen atoms, each isomer is emitted a
    second time with those nitrogens' chiral tags cleared.
    """
    mol = Chem.MolFromSmiles(smiles2D)
    dec_isomers = list(EnumerateStereoisomers(mol))
    # Round-trip through SMILES to obtain freshly parsed molecules.
    dec_isomers = [
        Chem.MolFromSmiles(Chem.MolToSmiles(isomer, isomericSmiles=True))
        for isomer in dec_isomers
    ]
    smiles3D = [Chem.MolToSmiles(isomer, isomericSmiles=True) for isomer in dec_isomers]

    chiralN = [
        atom.GetIdx()
        for atom in dec_isomers[0].GetAtoms()
        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"
    ]
    if len(chiralN) > 0:
        for isomer in dec_isomers:
            for idx in chiralN:
                isomer.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            smiles3D.append(Chem.MolToSmiles(isomer, isomericSmiles=True))
    return smiles3D


def sanitize(mol):
    """Round-trip ``mol`` through kekulized SMILES; return None on any failure."""
    try:
        smiles = get_smiles(mol)
        mol = get_mol(smiles)
    except Exception:
        return None
    return mol


def copy_atom(atom):
    """Create a new atom carrying the symbol, formal charge and map number of ``atom``."""
    new_atom = Chem.Atom(atom.GetSymbol())
    new_atom.SetFormalCharge(atom.GetFormalCharge())
    new_atom.SetAtomMapNum(atom.GetAtomMapNum())
    return new_atom


def copy_edit_mol(mol):
    """Return an editable ``RWMol`` copy of ``mol`` built from copied atoms and bonds."""
    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    for atom in mol.GetAtoms():
        new_atom = copy_atom(atom)
        new_mol.AddAtom(new_atom)
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        bt = bond.GetBondType()
        new_mol.AddBond(a1, a2, bt)
    return new_mol


def get_submol(mol, idxs, mark=()):
    """Extract the sub-molecule of ``mol`` induced by the atom indices ``idxs``.

    :param mol: source molecule.
    :param idxs: container of atom indices to keep.
    :param mark: container of atom indices whose copies receive atom-map
        number 1; all other copied atoms receive 0.
    :return: a new molecule containing the selected atoms and every bond
        whose two ends are both selected.
    """
    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    idx_map = {}  # index in ``mol`` -> index in ``new_mol``
    for atom in mol.GetAtoms():
        if atom.GetIdx() in idxs:
            new_atom = copy_atom(atom)
            if atom.GetIdx() in mark:
                new_atom.SetAtomMapNum(1)
            else:
                new_atom.SetAtomMapNum(0)
            idx_map[atom.GetIdx()] = new_mol.AddAtom(new_atom)
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if a1 in idxs and a2 in idxs:
            bt = bond.GetBondType()
            new_mol.AddBond(idx_map[a1], idx_map[a2], bt)
    return new_mol.GetMol()


def get_clique_mol(mol, atoms):
    """Return the sanitized fragment of ``mol`` spanned by ``atoms``."""
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    new_mol = copy_edit_mol(new_mol).GetMol()
    new_mol = sanitize(new_mol)  # We assume this is not None
    return new_mol


def get_clique_mol_simple(mol, cluster):
    """Return the unsanitized fragment of ``mol`` spanned by ``cluster``."""
    smile_cluster = Chem.MolFragmentToSmiles(mol, cluster, canonical=True, kekuleSmiles=True)
    mol_cluster = Chem.MolFromSmiles(smile_cluster, sanitize=False)
    return mol_cluster


def tree_decomp(mol, reference_vocab=None):
    """Decompose ``mol`` into clusters and connect them with a spanning tree.

    Clusters are the non-ring bonds plus the SSSR rings of at most 8 atoms;
    rings sharing more than two atoms are merged.

    :param mol: RDKit molecule.
    :param reference_vocab: optional mapping from fragment SMILES to a count;
        a ring merge is skipped when the merged fragment's entry is <= 99.
        Presumably a ``defaultdict``/``Counter`` -- a plain ``dict`` would
        raise ``KeyError`` for unseen fragments; verify against callers.
    :return: ``(clusters, edges)`` where ``clusters`` is a list of atom-index
        sets and ``edges`` a list of ``(cluster_i, cluster_j)`` pairs (empty
        when no two clusters share an atom).
    """
    edges = defaultdict(int)
    n_atoms = mol.GetNumAtoms()
    clusters = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            clusters.append({a1, a2})

    ssr = [set(x) for x in Chem.GetSymmSSSR(mol)]
    # remove too large circles
    ssr = [x for x in ssr if len(x) <= 8]
    clusters.extend(ssr)

    nei_list = [[] for _ in range(n_atoms)]
    for i, cluster in enumerate(clusters):
        for atom in cluster:
            nei_list[atom].append(i)

    # Merge Rings with intersection > 2 atoms/ at least 3 joint atoms
    # check the reference_vocab if it is not None
    for i in range(len(clusters)):
        if len(clusters[i]) <= 2:
            continue
        # ``clusters[i]`` is rebound (never mutated) on merge, so iterating
        # over the set object obtained here stays safe.
        for atom in clusters[i]:
            for j in nei_list[atom]:
                if i >= j or len(clusters[j]) <= 2:
                    continue
                inter = clusters[i] & clusters[j]
                if len(inter) > 2:
                    merge = clusters[i] | clusters[j]
                    if reference_vocab is not None:
                        smile_merge = Chem.MolFragmentToSmiles(
                            mol, merge, canonical=True, kekuleSmiles=True)
                        if reference_vocab[smile_merge] <= 99:
                            continue
                    clusters[i] = merge
                    clusters[j] = set()

    clusters = [c for c in clusters if len(c) > 0]
    nei_list = [[] for _ in range(n_atoms)]
    for i, cluster in enumerate(clusters):
        for atom in cluster:
            nei_list[atom].append(i)

    # Build edges
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        for i in range(len(cnei)):
            for j in range(i + 1, len(cnei)):
                c1, c2 = cnei[i], cnei[j]
                inter = set(clusters[c1]) & set(clusters[c2])
                if edges[(c1, c2)] < len(inter):
                    edges[(c1, c2)] = len(inter)  # cnei[i] < cnei[j] by construction

    # Invert weights so that the *minimum* spanning tree below is the
    # maximum spanning tree with respect to intersection size.
    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return clusters, edges

    # Compute Maximum Spanning Tree
    row, col, data = zip(*edges)
    n_clique = len(clusters)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return clusters, edges


def Brics_decomp(mol, reference_vocab=None):
    """Decompose ``mol`` into merged clusters of bonds and rings.

    Initial clusters are the bonds whose two atoms are both outside any ring
    plus the SSSR rings of at most 8 atoms; clusters sharing more than one
    atom are then merged.

    :param mol: RDKit molecule.
    :param reference_vocab: unused; kept for signature parity with
        :func:`tree_decomp`.
    :return: ``(clusters, edges)``.  Merged clusters are lists, untouched
        ones are sets.  ``edges`` is always the placeholder ``[(0, 0)]``.
    """
    clusters = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.GetBeginAtom().IsInRing() and not bond.GetEndAtom().IsInRing():
            clusters.append({a1, a2})

    # Splitting at BRICS bonds is currently disabled:
    # bre = list(BRICS.FindBRICSBonds(mol))
    # if len(bre) != 0:
    #     for bond in bre:
    #         if [bond[0][0], bond[0][1]] in clusters:
    #             clusters.remove([bond[0][0], bond[0][1]])
    #         else:
    #             clusters.remove([bond[0][1], bond[0][0]])
    #         clusters.append([bond[0][0]])
    #         clusters.append([bond[0][1]])

    ssr = [set(x) for x in Chem.GetSymmSSSR(mol)]
    # remove too large circles
    ssr = [x for x in ssr if len(x) <= 8]
    clusters.extend(ssr)

    # merge clusters
    n_clusters = len(clusters)
    for c in range(n_clusters - 1):
        for k in range(c + 1, n_clusters):
            if len(set(clusters[c]) & set(clusters[k])) > 1:
                clusters[c] = list(set(clusters[c]) | set(clusters[k]))
                clusters[k] = []  # emptied clusters are dropped below
    clusters = [c for c in clusters if len(c) > 0]

    edges = [(0, 0)]
    return clusters, edges


def atom_equal(a1, a2):
    """Return True when both atoms share element symbol and formal charge."""
    return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()


def ring_bond_equal(bond1, bond2, reverse=False):
    """Return True when the two bonds match in end atoms and bond type.

    With ``reverse=True`` the end atoms of ``bond2`` are compared in swapped
    order.
    """
    b1 = (bond1.GetBeginAtom(), bond1.GetEndAtom())
    if reverse:
        b2 = (bond2.GetEndAtom(), bond2.GetBeginAtom())
    else:
        b2 = (bond2.GetBeginAtom(), bond2.GetEndAtom())
    return (atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
            and bond1.GetBondType() == bond2.GetBondType())


def attach(ctr_mol, nei_mol, amap):
    """Attach ``nei_mol`` to a copy of ``ctr_mol`` following ``amap``.

    :param ctr_mol: centre molecule (not modified; an ``RWMol`` copy is used).
    :param nei_mol: molecule to attach.
    :param amap: dict mapping atom indices of ``nei_mol`` to atom indices of
        ``ctr_mol`` for the shared atoms.  It is extended IN PLACE with the
        indices of the newly added atoms.
    :return: ``(new_mol, amap)``; newly added atoms carry atom-map number 2.
    """
    ctr_mol = Chem.RWMol(ctr_mol)
    for atom in nei_mol.GetAtoms():
        if atom.GetIdx() not in amap:
            new_atom = copy_atom(atom)
            new_atom.SetAtomMapNum(2)
            amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

    for bond in nei_mol.GetBonds():
        a1 = amap[bond.GetBeginAtom().GetIdx()]
        a2 = amap[bond.GetEndAtom().GetIdx()]
        if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
            ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol.GetMol(), amap


def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
    """Attach every node in ``prev_nodes + neighbors`` to ``ctr_mol`` in place.

    :param ctr_mol: editable centre molecule; modified and returned.
    :param neighbors: nodes (objects exposing ``nid`` and ``mol``) to attach.
    :param prev_nodes: previously attached nodes; their bond types override
        bonds already present in ``ctr_mol``.
    :param nei_amap: dict ``nid -> {neighbor atom idx: centre atom idx}``;
        the inner dicts are extended in place.
    :return: ``ctr_mol``.
    """
    prev_nids = [node.nid for node in prev_nodes]
    for nei_node in prev_nodes + neighbors:
        nei_id, nei_mol = nei_node.nid, nei_node.mol
        amap = nei_amap[nei_id]
        for atom in nei_mol.GetAtoms():
            if atom.GetIdx() not in amap:
                new_atom = copy_atom(atom)
                amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

        if nei_mol.GetNumBonds() == 0:
            # Single-atom neighbor: only propagate its atom-map number.
            nei_atom = nei_mol.GetAtomWithIdx(0)
            ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
            ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
        else:
            for bond in nei_mol.GetBonds():
                a1 = amap[bond.GetBeginAtom().GetIdx()]
                a2 = amap[bond.GetEndAtom().GetIdx()]
                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
                elif nei_id in prev_nids:  # father node overrides
                    ctr_mol.RemoveBond(a1, a2)
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol


def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
    """Build the molecule obtained by attaching the given nodes to ``ctr_mol``.

    :param amap_list: iterable of ``(nei_id, ctr_atom, nei_atom)`` triples
        describing which neighbor atom coincides with which centre atom.
    :return: a new molecule; ``ctr_mol`` itself is copied, not modified.
    """
    ctr_mol = copy_edit_mol(ctr_mol)
    nei_amap = {nei.nid: {} for nei in prev_nodes + neighbors}

    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom

    ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap)
    return ctr_mol.GetMol()


# This version records idx mapping between ctr_mol and nei_mol
def enum_attach(ctr_mol, nei_mol):
    """Enumerate the ways ``nei_mol`` can share atoms/bonds with ``ctr_mol``.

    Only atoms of ``ctr_mol`` with atom-map number 1 (and bonds between two
    such atoms) are considered as attachment sites.  Both molecules are
    kekulized IN PLACE.

    :return: list of dicts ``{nei atom idx: ctr atom idx}``; empty when
        kekulization of either molecule fails.
    """
    try:
        Chem.Kekulize(ctr_mol)
        Chem.Kekulize(nei_mol)
    except Exception:
        return []

    att_confs = []
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()
                 if bond.GetBeginAtom().GetAtomMapNum() == 1
                 and bond.GetEndAtom().GetAtomMapNum() == 1]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetAtomMapNum() == 1]

    if nei_mol.GetNumBonds() == 1:  # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        # bond_val = int(bond.GetBondType())
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                new_amap = {b1.GetIdx(): atom.GetIdx()}
                att_confs.append(new_amap)
            elif atom_equal(atom, b2):
                new_amap = {b2.GetIdx(): atom.GetIdx()}
                att_confs.append(new_amap)
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change valence)
                    if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                        continue
                    amap = {a2.GetIdx(): a1.GetIdx()}
                    att_confs.append(amap)

        # intersection is a bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        amap = {b2.GetBeginAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                b2.GetEndAtom().GetIdx(): b1.GetEndAtom().GetIdx()}
                        att_confs.append(amap)

                    if ring_bond_equal(b1, b2, reverse=True):
                        amap = {b2.GetEndAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                b2.GetBeginAtom().GetIdx(): b1.GetEndAtom().GetIdx()}
                        att_confs.append(amap)
    return att_confs


def enumerate_assemble(mol, idxs, current, next):
    """Build a labelled (decoy, ground truth) candidate set for one attachment.

    :param mol: full molecule.
    :param idxs: atom indices of the part assembled so far.
    :param current: node whose ``clique`` atoms are the allowed attachment
        sites.
    :param next: node to attach (uses ``next.clique`` and ``next.mol``).
    :return: ``(labels, cand_mols)``.  With at least one valid decoy:
        one randomly sampled decoy followed by the ground truth, labels
        ``[0, 1]``; otherwise just the ground truth with label ``[1]``.
    """
    ctr_mol = get_submol(mol, idxs, mark=current.clique)
    # submol can also obtained with get_clique_mol, future exploration
    ground_truth = get_submol(mol, list(set(idxs) | set(next.clique)))
    ground_truth_smiles = get_smiles(ground_truth)

    cand_smiles = []
    cand_mols = []
    cand_amap = enum_attach(ctr_mol, next.mol)
    for amap in cand_amap:
        try:
            cand_mol, _ = attach(ctr_mol, next.mol, amap)
            cand_mol = sanitize(cand_mol)
        except Exception:
            # Best effort: an attachment RDKit rejects is simply not a candidate.
            continue
        if cand_mol is None:
            continue
        smiles = get_smiles(cand_mol)
        if smiles in cand_smiles or smiles == ground_truth_smiles:
            continue
        cand_smiles.append(smiles)
        cand_mols.append(cand_mol)

    if len(cand_mols) >= 1:
        cand_mols = sample(cand_mols, 1)
        cand_mols.append(ground_truth)
        labels = torch.tensor([0, 1])
    else:
        cand_mols = [ground_truth]
        labels = torch.tensor([1])

    return labels, cand_mols


# allowable node and edge features
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs': [  # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}


def mol_to_graph_data_obj_simple(mol):
    """
    Converts rdkit mol object to graph Data object required by the pytorch
    geometric package. NB: Uses simplified atom and bond features, and represent
    as indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # atoms
    num_atom_features = 2  # atom type, chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [
            allowable_features['possible_atomic_num_list'].index(atom.GetAtomicNum()),
            allowable_features['possible_chirality_list'].index(atom.GetChiralTag()),
        ]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [
                allowable_features['possible_bonds'].index(bond.GetBondType()),
                allowable_features['possible_bond_dirs'].index(bond.GetBondDir()),
            ]
            # Each bond is stored once per direction.
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data


# For inference
def assemble(mol_list, next_motif_smiles):
    """Enumerate valid attachments of the next motif to each molecule.

    :param mol_list: list of partially built molecules.
    :param next_motif_smiles: SMILES of the motif to attach to each molecule
        (same indexing as ``mol_list``).
    :return: tuple ``(cand_mols, cand_batch, new_atoms, one_atom_attach,
        intersection, attach_fail)``:

        * ``cand_mols`` -- candidate molecules; a molecule with no valid
          attachment contributes itself, unchanged;
        * ``cand_batch`` -- long tensor giving, per candidate, the index of
          the molecule it was built from;
        * ``new_atoms`` -- per candidate, the centre-molecule indices of all
          motif atoms after attachment (empty on failure);
        * ``one_atom_attach`` -- bool tensor, True where exactly one atom is
          shared;
        * ``intersection`` -- per candidate, the shared centre-atom indices
          (empty on failure);
        * ``attach_fail`` -- bool tensor over ``mol_list``, True where no
          valid attachment exists.
    """
    attach_fail = torch.zeros(len(mol_list)).bool()
    cand_mols, cand_batch, new_atoms, cand_smiles, one_atom_attach, intersection = \
        [], [], [], [], [], []

    for i in range(len(mol_list)):
        next_mol = Chem.MolFromSmiles(next_motif_smiles[i])
        cand_amap = enum_attach(mol_list[i], next_mol)

        valid_cand = 0
        for amap in cand_amap:
            # Record the shared atoms before ``attach`` extends ``amap`` in place.
            amap_len = len(amap)
            iter_atoms = [v for v in amap.values()]
            ctr_mol = deepcopy(mol_list[i])
            cand_mol, amap1 = attach(ctr_mol, next_mol, amap)
            if sanitize(deepcopy(cand_mol)) is None:
                continue
            smiles = get_smiles(cand_mol)
            cand_smiles.append(smiles)
            cand_mols.append(cand_mol)
            cand_batch.append(i)
            new_atoms.append([v for v in amap1.values()])
            one_atom_attach.append(amap_len)
            intersection.append(iter_atoms)
            valid_cand += 1

        # No attachment enumerated, or none survived sanitisation.
        if valid_cand == 0:
            attach_fail[i] = True
            cand_mols.append(mol_list[i])
            cand_batch.append(i)
            one_atom_attach.append(-1)
            intersection.append([])
            new_atoms.append([])

    cand_batch = torch.tensor(cand_batch)
    one_atom_attach = torch.tensor(one_atom_attach) == 1
    return cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail


if __name__ == "__main__":
    import sys
    from mol_tree import MolTree

    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
              "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
              "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
              "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
              'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
              'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
              "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
              "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]

    mol_tree = MolTree("C")
    assert len(mol_tree.nodes) > 0

    def count():
        """Read SMILES from stdin and tally candidate/node counts per tree."""
        cnt, n = 0, 0
        for s in sys.stdin:
            s = s.split()[0]
            tree = MolTree(s)
            tree.recover()
            tree.assemble()
            for node in tree.nodes:
                cnt += len(node.cands)
            n += len(tree.nodes)
        # print cnt * 1.0 / n

    count()