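"""Generate rectified (x0 -> x1) data couplings from a trained MDLM checkpoint.

For each requested sequence length, the script samples noise/denoised token
pairs, keeps only samples whose decoded SMILES are valid peptides, computes
peptide-bond token masks, and saves the result as train/validation/test
HuggingFace datasets.
"""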

import argparse
import os
import re

import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm

from train import MDLMLightningModule
from peptide_analyzer import PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer


def peptide_bond_mask(smiles_list):
    """
    Returns a mask of shape (batch_size, seq_length) with 1 at the character
    positions of recognized bonds and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
    """
    batch_size = len(smiles_list)
    # Character-level mask, padded to the maximum supported SMILES length.
    max_seq_length = 1035
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),
        (r'NC\(=O\)', 'peptide'),
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        positions = []
        used = set()

        # Patterns are matched in order; characters claimed by an earlier match
        # are not reused by later patterns.
        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        for pos in positions:
            mask[batch_idx, pos['start']:pos['end']] = 1

    return mask
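
# Illustrative example (assuming the pattern list above): for the alanine-alanine
# dipeptide SMILES "CC(N)C(=O)NC(C)C(=O)O", only the backbone amide "C(=O)N"
# (characters 5-10) matches a bond pattern, so
# peptide_bond_mask(["CC(N)C(=O)NC(C)C(=O)O"])[0, 5:11] is all ones and every
# other position is zero.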


def peptide_token_mask(smiles_list, token_lists):
    """
    Returns a mask of shape (batch_size, num_tokens) with 1 for tokens where
    any part of the token overlaps with a peptide bond, and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
    """
    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0  # character offset into the SMILES string

        for token_idx, token in enumerate(token_seq):
            # Skip the special start/end tokens, which do not consume SMILES characters.
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                # Flag the token if any character it covers lies on a bond.
                if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks
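
# Illustrative example (token strings are made up; the actual merges depend on
# the SPE vocabulary): if the tokens for the SMILES above were
#   ['[CLS]', 'CC(N)', 'C(=O)N', 'C(C)', 'C(=O)O', '[SEP]'],
# only 'C(=O)N' overlaps masked characters, so the token mask would be
#   [0, 0, 1, 0, 0, 0].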


def generate_and_filter_batch(model, tokenizer, peptide_analyzer, seq_len, batch_size, n_steps, temperature, device):
    """
    Generates a single batch of SMILES, filters them for validity, and returns the valid ones
    along with their corresponding noise tensors (x0) and final token tensors (x1).

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        peptide_analyzer (PeptideAnalyzer): The analyzer used to validate peptides.
        seq_len (int): The sequence length for this batch.
        batch_size (int): The number of samples to generate in this batch.
        n_steps (int): The number of steps for the flow matching process.
        temperature (float): The sampling temperature.
        device (str): The device to run generation on ('cuda' or 'cpu').

    Returns:
        tuple[list[str], list[torch.Tensor], list[torch.Tensor]]: A tuple containing:
            - A list of valid, generated peptide SMILES strings.
            - A list of the corresponding x0 tensors (noise).
            - A list of the corresponding x1 tensors (final generated tokens).
    """
    # x0 is pure noise: uniformly random token ids. It is returned alongside x1
    # so the (x0, x1) couplings can be stored for the rectified dataset.
    x0 = torch.randint(
        0,
        model.model.vocab_size,
        (batch_size, seq_len),
        device=device
    )
    x = x0.clone()

    # Evenly spaced time steps from t = 0 (pure noise) to t = 1 (clean data).
    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device)

    with torch.no_grad():
        for i in range(n_steps):
            t_curr = time_steps[i]
            t_tensor = torch.full((batch_size,), t_curr, device=device)

            logits = model(x, t_tensor)

            # Predict the clean sequence x1. With temperature > 0, sample from the
            # temperature-scaled distribution so the flag behaves as documented;
            # temperature == 0 falls back to greedy argmax.
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                pred_x1 = torch.distributions.Categorical(probs=probs).sample()
            else:
                pred_x1 = torch.argmax(logits, dim=-1)

            if i == n_steps - 1:
                # Final step: keep the fully denoised prediction.
                x = pred_x1
                break

            # Re-noise for the next step: each position is reset to a uniformly
            # random token with probability (1 - t_next) and otherwise kept as
            # the current prediction.
            t_next = time_steps[i + 1]
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, model.model.vocab_size, x.shape, device=device)
            x = torch.where(mask, noise, pred_x1)

    generated_sequences = tokenizer.batch_decode(x)

    # Keep only sequences the analyzer recognizes as valid peptides, together
    # with their original noise (x0) and final token (x1) tensors.
    valid_smiles = []
    valid_x0s = []
    valid_x1s = []
    for i, seq in enumerate(generated_sequences):
        if peptide_analyzer.is_peptide(seq):
            valid_smiles.append(seq)
            valid_x0s.append(x0[i])
            valid_x1s.append(x[i])

    return valid_smiles, valid_x0s, valid_x1s
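
# Illustrative usage (not executed here; main() below drives the real loop, and
# the argument values are made up):
#   smiles, x0s, x1s = generate_and_filter_batch(model, tokenizer, PeptideAnalyzer(),
#                                                seq_len=64, batch_size=8, n_steps=100,
#                                                temperature=1.0, device="cuda")
# Each (x0s[i], x1s[i]) pair is a noise/data coupling for the rectified dataset.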


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path, args=checkpoint["hyper_parameters"]["args"],
        tokenizer=tokenizer, strict=False
    ).to(device).eval()
    pa = PeptideAnalyzer()

    all_sources = []
    all_targets = []
    all_bonds = []

    # Generate from the longest to the shortest requested length.
    for length in range(args.max_length, args.min_length - 1, -1):
        print(f"\n--- Generating for length {length} ---")

        collected_for_len = 0
        pbar = tqdm(total=args.num_sequences_per_length, desc=f"Length {length}")

        chunk_source, chunk_target, chunk_bond = [], [], []
        # Cap the batch size so that batch_size * length stays within the token budget.
        max_batch_size = args.max_tokens_in_batch // length

        while collected_for_len < args.num_sequences_per_length:
            num_needed = args.num_sequences_per_length - collected_for_len

            # Generate at most enough sequences to fill the current chunk.
            gen_bsz = max_batch_size - len(chunk_target) if max_batch_size > len(chunk_target) else max_batch_size
            if gen_bsz == 0:
                print(f"Warning: Length {length} too long for token limit. Skipping.")
                break

            actual_bsz = min(num_needed, gen_bsz)

            smiles, x0s, x1s = generate_and_filter_batch(
                model, tokenizer, pa, length, actual_bsz,
                args.n_steps, args.temperature, device
            )

            if smiles:
                tokens = tokenizer.get_token_split(x1s)
                b_masks = peptide_token_mask(smiles, tokens)

                chunk_source.extend([x.tolist() for x in x0s])
                chunk_target.extend([x.tolist() for x in x1s])
                chunk_bond.extend(b_masks.tolist())

                collected_for_len += len(smiles)
                pbar.update(len(smiles))

            # Flush the chunk once it reaches a full batch.
            if len(chunk_target) == min(max_batch_size, args.num_sequences_per_length):
                all_sources.append(chunk_source)
                all_targets.append(chunk_target)
                all_bonds.append(chunk_bond)
                chunk_source, chunk_target, chunk_bond = [], [], []

        # Flush any partially filled final chunk so its sequences are not dropped.
        if chunk_target:
            all_sources.append(chunk_source)
            all_targets.append(chunk_target)
            all_bonds.append(chunk_bond)

        pbar.close()

    print("\n--- Combining all generated data chunks ---")
    # Each dataset row stores one chunk (a batch of equal-length sequences), so
    # these columns are nested lists of token-id lists.
    all_data = Dataset.from_dict({
        'source_ids': all_sources,
        'target_ids': all_targets,
        'bond_mask': all_bonds
    })
    print(f"Total valid sequences collected: {sum(len(chunk) for chunk in all_targets)}")

    print(f"Saving new rectified dataset to {args.output_dir}...")
    # 80/10/10 split: hold out 10% for test, then 1/9 of the remaining 90% for validation.
    train_val = all_data.train_test_split(test_size=0.1, seed=42)
    final_split = train_val['train'].train_test_split(test_size=(1/9), seed=42)

    final_split['train'].save_to_disk(os.path.join(args.output_dir, 'train'))
    final_split['test'].save_to_disk(os.path.join(args.output_dir, 'validation'))
    train_val['test'].save_to_disk(os.path.join(args.output_dir, 'test'))

    print("\nDataset combination and saving complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate rectified data couplings using a trained ReDi model for a range of lengths.")

    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the new rectified dataset.")

    parser.add_argument("--num_sequences_per_length", type=int, default=100, help="Number of valid sequences to generate for each length.")
    parser.add_argument("--min_length", type=int, default=4, help="Minimum sequence length to generate.")
    parser.add_argument("--max_length", type=int, default=1035, help="Maximum sequence length to generate (and padding length).")
    parser.add_argument("--max_tokens_in_batch", type=int, default=5200, help="Maximum number of tokens in a single generation batch (batch_size * seq_len).")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of steps for the flow matching process.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity. Set to 0 for pure argmax.")

    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
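
    # Example invocation (script name and paths below are illustrative):
    #   python generate_rectified_data.py \
    #       --checkpoint_path checkpoints/last.ckpt \
    #       --output_dir data/rectified_couplings \
    #       --num_sequences_per_length 100 --min_length 4 --max_length 1035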

    args = parser.parse_args()

    main(args)