"""Generate rectified (noise x0, sample x1) couplings with a trained ReDi model.

For every sequence length in [--min_length, --max_length] the model is run
from uniform token noise, the decoded SMILES are filtered with
PeptideAnalyzer.is_peptide, and the surviving couplings are stored (together
with a per-token peptide-bond mask) as same-length batches in a Hugging Face
dataset split into train/validation/test.
"""
import argparse
from pathlib import Path
import os
import re
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datasets import Dataset, concatenate_datasets
import pdb

# Import necessary classes from the project.
# Ensure that train.py, peptide_analyzer.py and the smiles_tokenizer package
# are on the Python path.
from train import MDLMLightningModule
from peptide_analyzer import PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer


# (compiled regex, bond type) pairs, tried in this order. A match is kept only
# if it does not overlap characters already claimed by an earlier match, so the
# earlier (more specific) patterns take precedence over the generic amide ones.
_BOND_PATTERNS = [
    (re.compile(r'OC\(=O\)'), 'ester'),
    (re.compile(r'N\(C\)C\(=O\)'), 'n_methyl'),
    (re.compile(r'N[12]C\(=O\)'), 'peptide'),  # Pro peptide bonds
    (re.compile(r'NC\(=O\)'), 'peptide'),  # Regular peptide bonds
    (re.compile(r'C\(=O\)N\(C\)'), 'n_methyl'),
    (re.compile(r'C\(=O\)N[12]?'), 'peptide'),
]

# Minimum width (in characters) of the mask returned by peptide_bond_mask.
# Matches the default --max_length below.
DEFAULT_BOND_MASK_WIDTH = 1035


def peptide_bond_mask(smiles_list, min_width=DEFAULT_BOND_MASK_WIDTH):
    """Mark the characters of each SMILES string that belong to a recognised bond.

    Args:
        smiles_list: List of peptide SMILES strings (a batch).
        min_width: Minimum number of columns in the returned mask. The mask is
            widened to the longest SMILES in the batch so that bonds beyond
            this column are not silently dropped.

    Returns:
        torch.Tensor: int tensor of shape (len(smiles_list), width), with 1 at
        every character position covered by a matched bond pattern and 0
        elsewhere.
    """
    width = max([min_width] + [len(smiles) for smiles in smiles_list])
    mask = torch.zeros((len(smiles_list), width), dtype=torch.int)

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed by a kept match
        for pattern, _bond_type in _BOND_PATTERNS:
            for match in pattern.finditer(smiles):
                span = range(match.start(), match.end())
                if used.isdisjoint(span):
                    mask[batch_idx, match.start():match.end()] = 1
                    used.update(span)
    return mask


def peptide_token_mask(smiles_list, token_lists):
    """Project the character-level bond mask onto tokens.

    A token gets 1 if any of its characters overlaps a bond found by
    peptide_bond_mask, and 0 otherwise. The first and last token of every
    sequence are treated as boundary tokens: they are always 0 and do not
    consume SMILES characters.

    NOTE(review): assumes that ''.join(token_lists[i][1:-1]) == smiles_list[i];
    TODO confirm against SMILES_SPE_Tokenizer.get_token_split / batch_decode.

    Args:
        smiles_list: List of peptide SMILES strings (a batch).
        token_lists: List of tokenised SMILES (one list of token strings per
            entry of smiles_list).

    Returns:
        torch.Tensor: int tensor of shape (batch_size, longest token list) with
        1 for tokens overlapping a bond.

    Raises:
        ValueError: if smiles_list and token_lists differ in length.
    """
    batch_size = len(smiles_list)
    if len(token_lists) != batch_size:
        raise ValueError(
            f"Got {batch_size} SMILES strings but {len(token_lists)} token lists"
        )
    if batch_size == 0:
        return torch.zeros((0, 0), dtype=torch.int)

    token_seq_length = max(len(tokens) for tokens in token_lists)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        last_idx = len(token_seq) - 1
        atom_idx = 0  # index of the first SMILES character of the current token
        for token_idx, token in enumerate(token_seq):
            if token_idx == 0 or token_idx == last_idx:
                continue  # boundary token: no SMILES characters to consume
            end = atom_idx + len(token)
            if atomwise_mask[atom_idx:end].any():
                tokenized_masks[batch_idx, token_idx] = 1
            atom_idx = end
    return tokenized_masks


def _select_tokens(logits, temperature):
    """Choose one token id per position from logits of shape (..., vocab_size).

    temperature <= 0 gives greedy argmax. Otherwise a token is sampled from
    softmax(logits / temperature): higher temperatures flatten the
    distribution and increase diversity.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1)
    # float32 keeps softmax/multinomial well-defined for half-precision logits.
    probs = F.softmax(logits.float() / temperature, dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    return torch.multinomial(flat, num_samples=1).reshape(probs.shape[:-1])


def generate_and_filter_batch(model, tokenizer, peptide_analyzer, seq_len, batch_size, n_steps, temperature, device):
    """
    Generates a single batch of SMILES, filters them for validity, and returns
    the valid ones along with their original corresponding noise tensors (x0)
    and final token tensors (x1).

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        peptide_analyzer (PeptideAnalyzer): The analyzer to validate peptides.
        seq_len (int): The sequence length for this batch.
        batch_size (int): The number of samples to generate in this batch.
        n_steps (int): The number of steps for the flow matching process (>= 1).
        temperature (float): Sampling temperature for the model's x1
            prediction; <= 0 means greedy argmax.
        device (str): The device to run generation on ('cuda' or 'cpu').

    Returns:
        tuple[list[str], list[torch.Tensor], list[torch.Tensor]]: A tuple containing:
            - A list of valid, generated peptide SMILES strings.
            - A list of the corresponding x0 tensors (noise).
            - A list of the corresponding x1 tensors (final generated tokens).

    Raises:
        ValueError: if n_steps < 1 (the output would otherwise be pure noise).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    vocab_size = model.model.vocab_size

    # 1. Start with a tensor of random tokens (pure noise at t=0).
    x0 = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    x = x0.clone()

    # 2. Time schedule for the forward process, 0.0 to 1.0 inclusive.
    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device)

    # 3. Iteratively follow the flow from noise to data.
    with torch.no_grad():
        for i in range(n_steps):
            t_curr = time_steps[i]
            t_tensor = torch.full((batch_size,), t_curr, device=device)

            # Model prediction of the clean sequence (t=1), one token per position.
            logits = model(x, t_tensor)
            pred_x1 = _select_tokens(logits, temperature)

            if i == n_steps - 1:
                x = pred_x1
                break

            # Build x_{t_next}: each position is replaced by uniform noise with
            # probability 1 - t_next, and otherwise takes the predicted token.
            t_next = time_steps[i + 1]
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, vocab_size, x.shape, device=device)
            x = torch.where(mask, noise, pred_x1)

    # 4. Decode the final tokens to SMILES strings.
    generated_sequences = tokenizer.batch_decode(x)

    # 5. Keep only (SMILES, x0, x1) triplets whose SMILES is a valid peptide.
    valid_smiles = []
    valid_x0s = []
    valid_x1s = []
    for i, seq in enumerate(generated_sequences):
        if peptide_analyzer.is_peptide(seq):
            valid_smiles.append(seq)
            valid_x0s.append(x0[i])
            valid_x1s.append(x[i])  # Store the final token tensor
    return valid_smiles, valid_x0s, valid_x1s


def _save_splits(dataset, output_dir, seed=42):
    """Split dataset rows 80/10/10 into disjoint train/validation/test and save each.

    The splits are written to <output_dir>/train, <output_dir>/validation and
    <output_dir>/test.
    """
    train_val = dataset.train_test_split(test_size=0.1, seed=seed)
    # 1/9 of the remaining 90% equals 10% of the whole for validation.
    train_valid = train_val['train'].train_test_split(test_size=(1 / 9), seed=seed)
    # Save the train half of the second split, so validation rows are not
    # also in train.
    train_valid['train'].save_to_disk(os.path.join(output_dir, 'train'))
    train_valid['test'].save_to_disk(os.path.join(output_dir, 'validation'))
    train_val['test'].save_to_disk(os.path.join(output_dir, 'test'))


def main(args):
    """Generate valid couplings for every length in [min_length, max_length] and save them.

    Each dataset row is one same-length batch of at most
    max_tokens_in_batch // length sequences, with columns 'source_ids' (x0),
    'target_ids' (x1) and 'bond_mask'.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source. The raw checkpoint is read only to
    # recover the training args stored in its hyper-parameters.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=checkpoint["hyper_parameters"]["args"],
        tokenizer=tokenizer,
        strict=False,
    ).to(device).eval()
    del checkpoint  # free the duplicate copy of the weights

    pa = PeptideAnalyzer()

    all_sources = []
    all_targets = []
    all_bonds = []
    total_sequences = 0

    for length in range(args.max_length, args.min_length - 1, -1):
        print(f"\n--- Generating for length {length} ---")

        max_batch_size = args.max_tokens_in_batch // length
        if max_batch_size == 0:
            print(f"Warning: Length {length} too long for token limit. Skipping.")
            continue

        # Rows of the output dataset hold this many same-length sequences.
        chunk_size = min(max_batch_size, args.num_sequences_per_length)
        collected_for_len = 0
        # Accumulators for the current "save batch" chunk
        chunk_source, chunk_target, chunk_bond = [], [], []

        # NOTE: loops until the per-length quota is met; if the model rarely
        # yields valid peptides at this length, this can take a long time.
        with tqdm(total=args.num_sequences_per_length, desc=f"Length {length}") as pbar:
            while collected_for_len < args.num_sequences_per_length:
                num_needed = args.num_sequences_per_length - collected_for_len
                actual_bsz = min(num_needed, chunk_size - len(chunk_target))

                smiles, x0s, x1s = generate_and_filter_batch(
                    model, tokenizer, pa, length, actual_bsz,
                    args.n_steps, args.temperature, device,
                )
                if not smiles:
                    continue

                tokens = tokenizer.get_token_split(x1s)
                b_masks = peptide_token_mask(smiles, tokens)

                chunk_source.extend(x.tolist() for x in x0s)
                chunk_target.extend(x.tolist() for x in x1s)
                chunk_bond.extend(b_masks.tolist())
                collected_for_len += len(smiles)
                pbar.update(len(smiles))

                # Save the chunk when it is full, and also flush the partial
                # tail once this length's quota is met so no sequences are lost.
                quota_met = collected_for_len >= args.num_sequences_per_length
                if len(chunk_target) == chunk_size or quota_met:
                    all_sources.append(chunk_source)
                    all_targets.append(chunk_target)
                    all_bonds.append(chunk_bond)
                    total_sequences += len(chunk_target)
                    chunk_source, chunk_target, chunk_bond = [], [], []

    print("\n--- Combining all generated data chunks ---")
    if not all_sources:
        print("No valid sequences were generated; nothing to save.")
        return

    all_data = Dataset.from_dict({
        'source_ids': all_sources,
        'target_ids': all_targets,
        'bond_mask': all_bonds,
    })

    print(f"Total valid sequences collected: {total_sequences} "
          f"(in {len(all_data)} same-length batches)")
    print(f"Saving new rectified dataset to {args.output_dir}...")
    _save_splits(all_data, args.output_dir)
    print("\nDataset combination and saving complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate rectified data couplings using a trained ReDi model for a range of lengths.")

    # --- Required Arguments ---
    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="Path to the model checkpoint (.ckpt file).")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the new rectified dataset.")

    # --- Generation Arguments ---
    parser.add_argument("--num_sequences_per_length", type=int, default=100,
                        help="Number of valid sequences to generate for each length.")
    parser.add_argument("--min_length", type=int, default=4,
                        help="Minimum sequence length to generate.")
    parser.add_argument("--max_length", type=int, default=1035,
                        help="Maximum sequence length to generate (and padding length).")
    parser.add_argument("--max_tokens_in_batch", type=int, default=5200,
                        help="Maximum number of tokens in a single generation batch (batch_size * seq_len).")
    parser.add_argument("--n_steps", type=int, default=100,
                        help="Number of steps for the flow matching process.")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature. Higher values increase diversity. Set to 0 for pure argmax.")

    # --- Environment Arguments ---
    parser.add_argument("--vocab_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
                        help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str,
                        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt',
                        help="Path to tokenizer splits file.")

    args = parser.parse_args()
    main(args)