# Pocket-Gen / models / data.py
# Repository page header captured along with the file (not part of the program):
#   "Zaixi's picture", "1", "dcacefd"
import torch
import numpy as np
import json, copy
import random
import glob
import csv
import os
import re
import Bio.PDB
import pickle
import torch.nn.functional as F
from tqdm import tqdm
from collections import defaultdict
from bindgen.utils import full_square_dist
# One-letter -> three-letter residue codes for the 20 standard amino acids.
RESTYPE_1to3 = {
"A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN","E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
}
# Residue vocabulary: a residue's integer code is its position in this list
# (looked up with ALPHABET.index elsewhere in this file). Index 0 is '#',
# presumably a padding/unknown token -- TODO confirm against the model code.
ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
# Atom-name vocabulary: an atom's integer type is its position in this list
# (looked up with ATOM_TYPES.index). Index 0 is the empty name '', which is
# what unused atom slots in RES_ATOM14 map to.
ATOM_TYPES = [
'', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
'CZ3', 'NZ', 'OXT'
]
# Per-residue atom layout: one row per ALPHABET entry (same order), each row
# holding exactly 14 atom-name slots. Slots a residue does not use are ''.
# Rows are selected with RES_ATOM14[ALPHABET.index(residue)].
RES_ATOM14 = [
[''] * 14,  # '#'
['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],  # A
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],  # R
['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],  # N
['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],  # D
['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],  # C
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],  # Q
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],  # E
['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],  # G (no CB)
['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],  # H
['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],  # I
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],  # L
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],  # K
['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],  # M
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],  # F
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],  # P
['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],  # S
['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],  # T
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],  # W (uses all 14 slots)
['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],  # Y
['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],  # V
]
class AntibodyComplexDataset():
    """In-memory dataset of antibody/antigen complexes read from a JSONL file.

    Each line of the file is parsed as one JSON entry that must provide the
    keys ``antibody_seq``, ``antibody_coords``, ``antibody_cdr``,
    ``antigen_seq`` and ``antigen_coords``. Every kept entry is the parsed
    dict extended with:

    * ``binder_surface`` / ``binder_seq`` / ``binder_coords`` /
      ``binder_atypes`` -- the antibody residues whose ``antibody_cdr``
      label is contained in ``cdr_type``;
    * ``target_seq`` / ``target_coords`` / ``target_atypes`` -- the whole
      antigen;
    * ``target_surface`` -- sorted indices of at most ``L_target`` antigen
      residues with the smallest value returned by ``full_square_dist``
      (presumably the residues closest to the binder -- confirm against
      ``bindgen.utils``).
    """

    def __init__(self, jsonl_file, cdr_type, L_target):
        """Load, annotate and filter all entries of ``jsonl_file``.

        Args:
            jsonl_file: path of the JSONL file to read.
            cdr_type: string of CDR labels; used both as a substring to count
                in ``antibody_cdr`` and as the set of accepted per-residue
                labels.
            L_target: upper bound on the number of target-surface residues.
        """
        self.data = []
        with open(jsonl_file) as f:
            all_lines = f.readlines()
            for line in tqdm(all_lines):
                entry = json.loads(line)
                assert len(entry['antibody_coords']) == len(entry['antibody_seq'])
                assert len(entry['antigen_coords']) == len(entry['antigen_seq'])

                cdr_labels = entry['antibody_cdr']
                # Entries with too few occurrences of the requested CDR are dropped.
                if cdr_labels.count(cdr_type) <= 4:
                    continue

                # paratope region
                positions = [pos for pos, label in enumerate(cdr_labels) if label in cdr_type]
                paratope = torch.tensor(positions)
                antibody_seq = entry['antibody_seq']
                entry['binder_surface'] = paratope
                entry['binder_seq'] = ''.join(antibody_seq[pos] for pos in positions)
                entry['binder_coords'] = torch.tensor(entry['antibody_coords'])[paratope]
                entry['binder_atypes'] = self._masked_atom_types(
                    entry['binder_seq'], entry['binder_coords']
                )

                # Create target
                entry['target_seq'] = entry['antigen_seq']
                entry['target_coords'] = torch.tensor(entry['antigen_coords'])
                entry['target_atypes'] = self._masked_atom_types(
                    entry['target_seq'], entry['target_coords']
                )

                # Find target surface
                dist, _ = full_square_dist(
                    entry['target_coords'][None, ...],
                    entry['binder_coords'][None, ...],
                    entry['target_atypes'][None, ...],
                    entry['binder_atypes'][None, ...],
                    contact=True
                )
                per_target = dist[0]
                n_keep = min(len(per_target), L_target)
                closest = per_target.amin(dim=-1).topk(k=n_keep, largest=False).indices
                entry['target_surface'] = torch.sort(closest).values

                # Final filter: both chains need more than 4 residues and the
                # CDR label string may contain '001' at most once.
                if len(entry['binder_coords']) <= 4:
                    continue
                if len(entry['target_coords']) <= 4:
                    continue
                if cdr_labels.count('001') > 1:
                    continue
                self.data.append(entry)

    @staticmethod
    def _masked_atom_types(seq, coords):
        """Return per-residue atom-type ids for ``seq``, zeroed where ``coords`` is ~0.

        An atom slot whose coordinate vector has norm <= 1e-6 gets type 0.
        """
        atypes = torch.tensor(
            [[ATOM_TYPES.index(atom) for atom in RES_ATOM14[ALPHABET.index(res)]] for res in seq]
        )
        present = (coords.norm(dim=-1) > 1e-6).long()
        atypes *= present
        return atypes

    def __len__(self):
        """Number of entries that passed all filters."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the annotated entry dict at position ``idx``."""
        return self.data[idx]
class ComplexLoader():
    """Group dataset entries into length-sorted batches and iterate over them.

    Entries are ordered by the length of their ``'binder_seq'`` and packed
    greedily: a batch is closed as soon as one more entry of the current
    length would push ``length * batch_size`` above ``batch_tokens``.
    """

    def __init__(self, dataset, batch_tokens):
        """Pre-compute the batches (``self.clusters``) for ``dataset``.

        Args:
            dataset: indexable collection of dicts with a ``'binder_seq'`` key;
                must support ``len()`` and integer indexing.
            batch_tokens: budget for ``sequence length * batch size``.
        """
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[pos]['binder_seq']) for pos in range(self.size)]
        self.batch_tokens = batch_tokens

        # Cluster into batches of similar sizes
        self.clusters = []
        pending = []
        for ix in np.argsort(self.lengths):
            pending.append(ix)
            # Close the batch if adding another entry of this length would
            # exceed the token budget.
            if self.lengths[ix] * (len(pending) + 1) > self.batch_tokens:
                self.clusters.append(pending)
                pending = []
        if pending:
            self.clusters.append(pending)

    def __len__(self):
        """Number of batches."""
        return len(self.clusters)

    def __iter__(self):
        """Yield each batch as a list of dataset entries, in a freshly shuffled batch order."""
        # Shuffles the stored batch list in place; batch contents are unchanged.
        np.random.shuffle(self.clusters)
        for members in self.clusters:
            yield [self.dataset[pos] for pos in members]
def make_batch_from_seq(batch, device=None):
    """Encode a list of residue strings as a zero-padded index tensor plus mask.

    Args:
        batch: sequence of strings whose characters are all in ``ALPHABET``.
        device: torch device for the returned tensors. When ``None`` the
            historical default ``'cuda:1'`` is used if a second GPU exists,
            otherwise ``'cuda:0'`` if any GPU exists, otherwise ``'cpu'``.

    Returns:
        ``(S, mask)`` where ``S`` is an int64 tensor of shape
        ``[len(batch), L_max]`` holding ``ALPHABET`` indices (0 in padded
        positions) and ``mask`` is a float32 tensor of the same shape with 1
        at real positions and 0 at padding. ``L_max`` is the longest
        sequence length, or 0 for an empty batch.

    Raises:
        ValueError: if a character is not in ``ALPHABET``.
    """
    if device is None:
        # The original code hard-coded 'cuda:1', which fails on hosts with
        # fewer than two GPUs; keep it as the default only where it exists.
        if torch.cuda.device_count() > 1:
            device = 'cuda:1'
        elif torch.cuda.is_available():
            device = 'cuda:0'
        else:
            device = 'cpu'

    B = len(batch)
    # default=0 keeps an empty batch from raising inside max().
    L_max = max((len(seq) for seq in batch), default=0)
    S = np.zeros([B, L_max], dtype=np.int32)
    mask = np.zeros([B, L_max], dtype=np.float32)
    for i, seq in enumerate(batch):
        l = len(seq)
        S[i, :l] = np.asarray([ALPHABET.index(a) for a in seq], dtype=np.int32)
        mask[i, :l] = 1.
    S = torch.from_numpy(S).long().to(device)
    mask = torch.from_numpy(mask).float().to(device)
    return S, mask
def featurize(batch, name, device=None):
    """Stack the ``name``-prefixed fields of a batch into zero-padded tensors.

    Args:
        batch: sequence of entry dicts providing ``name + '_seq'``,
            ``name + '_coords'`` and ``name + '_atypes'``, and optionally
            ``name + '_dihedrals'``.
        name: field prefix, e.g. ``'binder'`` or ``'target'``.
        device: torch device for the returned tensors. When ``None`` the
            historical default ``'cuda:1'`` is used if a second GPU exists,
            otherwise ``'cuda:0'`` if any GPU exists, otherwise ``'cpu'``.

    Returns:
        ``(X, S, A, V)``:
        ``X`` float ``[B, L_max, 14, 3]`` coordinates,
        ``S`` int64 ``[B, L_max]`` ``ALPHABET`` indices,
        ``A`` int64 ``[B, L_max, 14]`` atom types,
        ``V`` float ``[B, L_max, 12]`` dihedral features (zeros when the entry
        has no ``name + '_dihedrals'`` key).
        Positions beyond an entry's sequence length are left at zero.

    Raises:
        ValueError: if a residue character is not in ``ALPHABET``.
    """
    if device is None:
        # The original code hard-coded 'cuda:1', which fails on hosts with
        # fewer than two GPUs; keep it as the default only where it exists.
        if torch.cuda.device_count() > 1:
            device = 'cuda:1'
        elif torch.cuda.is_available():
            device = 'cuda:0'
        else:
            device = 'cpu'

    seq_key = name + '_seq'
    dihedral_key = name + '_dihedrals'

    B = len(batch)
    # default=0 keeps an empty batch from raising inside max().
    L_max = max((len(b[seq_key]) for b in batch), default=0)
    X = torch.zeros([B, L_max, 14, 3])
    S = torch.zeros([B, L_max], dtype=torch.long)
    A = torch.zeros([B, L_max, 14], dtype=torch.long)
    V = torch.zeros([B, L_max, 12])

    # Build the batch
    for i, b in enumerate(batch):
        l = len(b[seq_key])
        X[i, :l] = b[name + '_coords']
        A[i, :l] = b[name + '_atypes']
        # V starts as zeros, so entries without dihedrals need no write.
        if dihedral_key in b:
            V[i, :l] = b[dihedral_key]
        # Explicit dtype: torch.tensor([]) would otherwise be float for an
        # empty sequence.
        S[i, :l] = torch.tensor([ALPHABET.index(a) for a in b[seq_key]], dtype=torch.long)
    return X.to(device), S.to(device), A.to(device), V.to(device)
def make_batch(batch):
    """Featurize both chains of every entry and collect their surface indices.

    Args:
        batch: sequence of entry dicts carrying the ``'target_*'`` and
            ``'binder_*'`` fields read by ``featurize`` plus
            ``'binder_surface'`` and ``'target_surface'``.

    Returns:
        ``(binder, target, surface)`` where ``binder`` and ``target`` are the
        tuples returned by ``featurize`` and ``surface`` is the pair
        ``(list of binder_surface, list of target_surface)``.
    """
    # The target is featurized before the binder, then the surfaces are gathered.
    target_feats = featurize(batch, 'target')
    binder_feats = featurize(batch, 'binder')
    binder_surfaces = [entry['binder_surface'] for entry in batch]
    target_surfaces = [entry['target_surface'] for entry in batch]
    return binder_feats, target_feats, (binder_surfaces, target_surfaces)