import copy import random import torch import numpy as np from torch_geometric.data import Data, Batch from torch_scatter import scatter_sum # from torch_geometric.loader import DataLoader from torch.utils.data import Dataset FOLLOW_BATCH = ['protein_element', 'ligand_context_element', 'pos_real', 'pos_fake'] class ProteinLigandData(object): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs): instance = ProteinLigandData(**kwargs) if protein_dict is not None: for key, item in protein_dict.items(): instance['protein_' + key] = item if ligand_dict is not None: for key, item in ligand_dict.items(): instance['ligand_' + key] = item # instance['ligand_nbh_list'] = {i.item():[j.item() for k, j in enumerate(instance.ligand_bond_index[1]) if instance.ligand_bond_index[0, k].item() == i] for i in instance.ligand_bond_index[0]} return instance def batch_from_data_list(data_list): return Batch.from_data_list(data_list, follow_batch=['ligand_element', 'protein_element']) def torchify_dict(data): output = {} for k, v in data.items(): if isinstance(v, np.ndarray): output[k] = torch.from_numpy(v) else: output[k] = v return output def collate_mols(mol_dicts): data_batch = {} batch_size = len(mol_dicts) for key in ['protein_pos', 'protein_atom_feature', 'ligand_pos', 'ligand_atom_feature', 'protein_edit_residue', 'amino_acid', 'res_idx', 'residue_natoms', 'protein_atom_to_aa_type']: data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0) # residue pos data_batch['residue_pos'] = \ torch.cat([torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0).unsqueeze(0) for key in ['pos_N', 'pos_CA', 'pos_C', 'pos_O']], dim=0).permute(1,0,2) # random mask residues for the second stage (one residue per protein) tmp = [] for mol_dict in mol_dicts: ind = torch.multinomial(mol_dict['protein_edit_residue'].float(), 1) selected = torch.zeros_like(mol_dict['protein_edit_residue'], dtype=bool) selected[ind] = 1 tmp.append(selected) data_batch['random_mask_residue'] = torch.cat(tmp, dim=0) # remove side chains for the masked atoms num_residues = len(data_batch['amino_acid']) data_batch['atom2residue'] = torch.repeat_interleave(torch.arange(num_residues), data_batch['residue_natoms']) index1 = torch.arange(len(data_batch['amino_acid']))[data_batch['random_mask_residue']] index2 = torch.arange(len(data_batch['amino_acid']))[data_batch['protein_edit_residue']] for key in ['protein_pos', 'protein_atom_feature']: tmp1, tmp2 = [], [] for k in range(num_residues): mask = data_batch['atom2residue'] == k if k in index1: tmp1.append(data_batch[key][mask][:4]) else: tmp1.append(data_batch[key][mask]) if k in index2: tmp2.append(data_batch[key][mask][:4]) else: tmp2.append(data_batch[key][mask]) data_batch[key] = torch.cat(tmp1, dim=0) data_batch[key + '_backbone'] = torch.cat(tmp2, dim=0) data_batch['residue_natoms'][data_batch['random_mask_residue']] = 4 data_batch['atom2residue'] = torch.repeat_interleave(torch.arange(len(data_batch['residue_natoms'])), data_batch['residue_natoms']) # follow batch for key in ['ligand_atom_feature', 'amino_acid']: repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts]) if key == 'amino_acid': data_batch['amino_acid_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats) else: data_batch['ligand_atom_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats) repeats = scatter_sum(data_batch['residue_natoms'], data_batch['amino_acid_batch'], dim=0) data_batch['protein_atom_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats) # backbone protein for the first stage data_batch['residue_natoms_backbone'] = copy.deepcopy(data_batch['residue_natoms']) data_batch['residue_natoms_backbone'][data_batch['protein_edit_residue']] = 4 repeats = scatter_sum(data_batch['residue_natoms_backbone'], data_batch['amino_acid_batch'], dim=0) data_batch['protein_atom_batch_backbone'] = torch.repeat_interleave(torch.arange(batch_size), repeats) data_batch['atom2residue_backbone'] = torch.repeat_interleave(torch.arange(len(data_batch['residue_natoms_backbone'])), data_batch['residue_natoms_backbone']) data_batch['protein_edit_atom'] = torch.repeat_interleave(data_batch['protein_edit_residue'], data_batch['residue_natoms'], dim=0) data_batch['protein_edit_atom_backbone'] = torch.repeat_interleave(data_batch['protein_edit_residue'], data_batch['residue_natoms_backbone'], dim=0) data_batch['random_mask_atom'] = torch.repeat_interleave(data_batch['random_mask_residue'], data_batch['residue_natoms'], dim=0) data_batch['edit_sidechain'] = copy.deepcopy(data_batch['protein_edit_atom']) data_batch['edit_backbone'] = copy.deepcopy(data_batch['protein_edit_atom']) index = torch.arange(len(data_batch['amino_acid']))[data_batch['protein_edit_residue']] for k in range(num_residues): mask = data_batch['atom2residue'] == k if k in index: data_mask1, data_mask2 = data_batch['edit_sidechain'][mask], data_batch['edit_backbone'][mask] data_mask1[:4], data_mask2[4:] = 0, 0 data_batch['edit_sidechain'][mask] = data_mask1 data_batch['edit_backbone'][mask] = data_mask2 return data_batch def collate_mols_block(mol_dicts, batch_converter): data_batch = {} batch_size = len(mol_dicts) for key in ['protein_pos', 'protein_atom_feature', 'protein_atom_name', 'protein_edit_residue', 'amino_acid', 'residue_natoms', 'protein_atom_to_aa_type', 'res_idx', 'ligand_element', 'ligand_bond_type']: data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0) edge_num = torch.tensor([len(mol_dict['ligand_bond_type']) for mol_dict in mol_dicts]) ligand_atom_num = torch.tensor([len(mol_dict['ligand_element']) for mol_dict in mol_dicts]) data_batch['edge_batch'] = torch.repeat_interleave(torch.arange(batch_size), edge_num) data_batch['ligand_batch'] = torch.repeat_interleave(torch.arange(batch_size), ligand_atom_num) data_batch['ligand_bond_index'] = torch.cat([mol_dict['ligand_bond_index'] for mol_dict in mol_dicts], dim=1) # protein backbone pos data_batch['backbone_pos'] = torch.cat([torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0).unsqueeze(0) for key in ['pos_N', 'pos_CA', 'pos_C', 'pos_O']], dim=0).permute(1, 0, 2) # protein residue/feature for residue level encoding num_residues = len(data_batch['amino_acid']) data_batch['amino_acid_processed'] = copy.deepcopy(data_batch['amino_acid']) data_batch['amino_acid_processed'][data_batch['protein_edit_residue']] = 0 data_batch['atom2residue'] = torch.repeat_interleave(torch.arange(num_residues), data_batch['residue_natoms']) data_batch['residue_pos'] = torch.zeros(num_residues, 14, 3).to(data_batch['amino_acid'].device) # data_batch['residue_feat'] = torch.zeros(num_residues, 14, 38).to(data_batch['amino_acid'].device) index = torch.arange(num_residues)[data_batch['protein_edit_residue']] for k in range(num_residues): mask = data_batch['atom2residue'] == k data_batch['residue_pos'][k][:min(data_batch['residue_natoms'][k].item(), 14)] = data_batch['protein_pos'][mask][:min(data_batch['residue_natoms'][k].item(), 14)] ''' if k in index: data_batch['residue_feat'][k][:4] = data_batch['protein_atom_feature'][mask][:4] else: data_batch['residue_feat'][k][:data_batch['residue_natoms'][k]] = data_batch['protein_atom_feature'][mask] ''' # residue, ligand, protein atom follow batch repeats = torch.tensor([len(mol_dict['amino_acid']) for mol_dict in mol_dicts]) data_batch['amino_acid_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats) # ligand pos feat data_batch['ligand_natoms'] = torch.tensor([len(mol_dict['ligand_pos']) for mol_dict in mol_dicts]) max_ligand_atoms = max([len(mol_dict['ligand_pos']) for mol_dict in mol_dicts]) data_batch['ligand_pos'] = torch.zeros(batch_size, max_ligand_atoms, 3).to(data_batch['amino_acid'].device) data_batch['ligand_feat'] = torch.zeros(batch_size, max_ligand_atoms, 15).to(data_batch['amino_acid'].device) data_batch['ligand_mask'] = torch.zeros(batch_size, max_ligand_atoms).to(data_batch['amino_acid'].device) for b in range(batch_size): data_batch['ligand_pos'][b][:data_batch['ligand_natoms'][b]] = mol_dicts[b]['ligand_pos'] data_batch['ligand_feat'][b][:data_batch['ligand_natoms'][b]] = mol_dicts[b]['ligand_atom_feature'] data_batch['ligand_mask'][b, :data_batch['ligand_natoms'][b]] = 1 data_batch['edit_residue_num'] = torch.tensor([mol_dict['protein_edit_residue'].sum() for mol_dict in mol_dicts]).to(data_batch['amino_acid'].device) data_batch['seq'] = [('', mol_dict['seq']) for mol_dict in mol_dicts] _, _, data_batch['seq'] = batch_converter(data_batch['seq']) mask_id = 32 data_batch['full_seq_mask'] = torch.zeros_like(data_batch['seq']).bool() data_batch['r10_mask'] = torch.zeros_like(data_batch['seq']).bool() for b in range(batch_size): data_batch['seq'][b][mol_dicts[b]['full_seq_idx']+1] = mask_id data_batch['full_seq_mask'][b][mol_dicts[b]['full_seq_idx']+1] = True data_batch['r10_mask'][b][mol_dicts[b]['r10_idx'] + 1] = True data_batch['protein_filename'] = [mol_dict['whole_protein_name'] for mol_dict in mol_dicts] data_batch['pocket_filename'] = [mol_dict['protein_filename'] for mol_dict in mol_dicts] data_batch['ligand_filename'] = [mol_dict['ligand_filename'] for mol_dict in mol_dicts] return data_batch