# AReUReDi / smiles / new_coupling.py
# Tong Chen
# add files
# 295b1cd
import argparse
from pathlib import Path
import os
import re
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datasets import Dataset, concatenate_datasets
import pdb
# Project-local imports: train.py, peptide_analyzer.py and the smiles_tokenizer
# package must be importable (i.e. on the Python path).
from train import MDLMLightningModule
from peptide_analyzer import PeptideAnalyzer
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
def peptide_bond_mask(smiles_list, max_seq_length=1035):
    """
    Build a character-level mask marking recognised bond patterns in each SMILES string.

    Patterns are tried in a fixed priority order. A match is kept only if it
    does not overlap any character already claimed by an earlier match in the
    same string.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        max_seq_length: Minimum width of the mask (default 1035). If any SMILES
            string is longer, the mask is widened to that length so bonds past
            this column are not silently dropped.

    Returns:
        torch.Tensor: int tensor of shape (batch_size, width), where
        width = max(max_seq_length, length of the longest SMILES string).
        It holds 1 on every character covered by a recognised bond and 0 elsewhere.
    """
    # (regex, bond type) in priority order; the type label is informational only.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),  # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide'),
    ]

    # Never narrower than the longest string, otherwise trailing bonds would be lost.
    longest = max((len(smiles) for smiles in smiles_list), default=0)
    width = max(max_seq_length, longest)
    mask = torch.zeros((len(smiles_list), width), dtype=torch.int)

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed by an accepted match
        for pattern, _bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                if used.isdisjoint(span):
                    used.update(span)
                    mask[batch_idx, match.start():match.end()] = 1
    return mask
def peptide_token_mask(smiles_list, token_lists):
    """
    Project the character-level bond mask from ``peptide_bond_mask`` onto tokens.

    A token is marked 1 when any character it covers overlaps a recognised bond.
    The first and last token of every sequence are never marked and do not
    advance the character cursor.

    NOTE(review): this presumes those two are special tokens absent from the
    SMILES string, and that the remaining tokens concatenate to it exactly.
    Confirm against the tokenizer's ``get_token_split``.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: One token sequence per SMILES string. Tokens are presumably
            strings; their ``len()`` is used as a character count.

    Returns:
        torch.Tensor: int tensor of shape (batch_size, num_tokens), where
        num_tokens is the length of the longest token sequence (0 when
        ``token_lists`` is empty). It holds 1 on bond-overlapping tokens.
    """
    batch_size = len(smiles_list)
    # default=0 lets an empty batch yield an empty mask instead of raising ValueError.
    token_seq_length = max((len(tokens) for tokens in token_lists), default=0)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    if batch_size == 0:
        return tokenized_masks

    atomwise_masks = peptide_bond_mask(smiles_list)
    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        last_idx = len(token_seq) - 1
        atom_idx = 0  # character offset into the SMILES string
        for token_idx, token in enumerate(token_seq):
            if 0 < token_idx < last_idx:
                if atomwise_mask[atom_idx:atom_idx + len(token)].any():
                    tokenized_masks[batch_idx, token_idx] = 1
                atom_idx += len(token)
    return tokenized_masks
def generate_and_filter_batch(model, tokenizer, peptide_analyzer, seq_len, batch_size, n_steps, temperature, device):
    """
    Generate one batch of SMILES, keep the valid peptides, and return them with their couplings.

    Starting from uniformly random tokens x0, the model is queried at each
    step of a uniform time grid on [0, 1] for a clean-sequence prediction.
    Between steps, each position is re-noised with probability 1 - t_next.
    The prediction at the final step is the generated sequence x1.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        peptide_analyzer (PeptideAnalyzer): The analyzer to validate peptides.
        seq_len (int): The sequence length for this batch.
        batch_size (int): The number of samples to generate in this batch.
        n_steps (int): The number of steps for the flow matching process.
        temperature (float): Sampling temperature. If > 0, each token is sampled
            from softmax(logits / temperature); higher values increase diversity.
            If <= 0, the argmax token is taken (greedy).
        device (str): The device to run generation on ('cuda' or 'cpu').

    Returns:
        tuple[list[str], list[torch.Tensor], list[torch.Tensor]]: A tuple containing:
            - A list of valid, generated peptide SMILES strings.
            - A list of the corresponding x0 tensors (noise).
            - A list of the corresponding x1 tensors (final generated tokens).
    """
    vocab_size = model.model.vocab_size

    # 1. Start with a tensor of random tokens (pure noise at t=0).
    x0 = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    x = x0.clone()

    # 2. Uniform time schedule from 0.0 (noise) to 1.0 (data).
    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device)

    # 3. Iteratively follow the flow from noise to data.
    with torch.no_grad():
        for i in range(n_steps):
            t_tensor = torch.full((batch_size,), time_steps[i], device=device)
            # Model's prediction for the final clean sequence (at t=1).
            logits = model(x, t_tensor)
            if temperature > 0:
                # Dividing by temperature alone cannot change an argmax, so the
                # temperature only takes effect when tokens are sampled.
                pred_x1 = torch.distributions.Categorical(logits=logits / temperature).sample()
            else:
                pred_x1 = torch.argmax(logits, dim=-1)

            if i == n_steps - 1:
                x = pred_x1
                break

            # Build x_{t_next}: keep the prediction, except re-noise each
            # position with probability 1 - t_next.
            noise_prob = 1.0 - time_steps[i + 1]
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, vocab_size, x.shape, device=device)
            x = torch.where(mask, noise, pred_x1)

    # 4. Decode the final token tensors.
    generated_sequences = tokenizer.batch_decode(x)

    # 5. Keep only valid peptides, together with their (x0, x1) coupling.
    valid_smiles, valid_x0s, valid_x1s = [], [], []
    for i, seq in enumerate(generated_sequences):
        if peptide_analyzer.is_peptide(seq):
            valid_smiles.append(seq)
            valid_x0s.append(x0[i])
            valid_x1s.append(x[i])  # final token tensor
    return valid_smiles, valid_x0s, valid_x1s
def _collect_length_chunks(model, tokenizer, pa, length, args, device):
    """Generate the requested number of valid peptides of one length, grouped into token-budgeted chunks."""
    max_batch_size = args.max_tokens_in_batch // length
    if max_batch_size == 0:
        print(f"Warning: Length {length} too long for token limit. Skipping.")
        return [], [], []

    target = args.num_sequences_per_length
    chunk_limit = min(max_batch_size, target)
    sources, targets, bonds = [], [], []
    chunk_source, chunk_target, chunk_bond = [], [], []
    collected = 0

    with tqdm(total=target, desc=f"Length {length}") as pbar:
        # NOTE(review): there is no attempt cap; this loops until enough valid
        # peptides have been produced for this length.
        while collected < target:
            # Request only what is still needed and what still fits in the
            # current chunk, so a chunk never exceeds max_batch_size sequences.
            actual_bsz = min(target - collected, max_batch_size - len(chunk_target))
            smiles, x0s, x1s = generate_and_filter_batch(
                model, tokenizer, pa, length, actual_bsz,
                args.n_steps, args.temperature, device
            )
            if smiles:
                tokens = tokenizer.get_token_split(x1s)
                b_masks = peptide_token_mask(smiles, tokens)
                chunk_source.extend(x.tolist() for x in x0s)
                chunk_target.extend(x.tolist() for x in x1s)
                chunk_bond.extend(b_masks.tolist())
                collected += len(smiles)
                pbar.update(len(smiles))

            # A full chunk becomes one dataset row.
            if len(chunk_target) == chunk_limit:
                sources.append(chunk_source)
                targets.append(chunk_target)
                bonds.append(chunk_bond)
                chunk_source, chunk_target, chunk_bond = [], [], []

    # Keep the final, partially filled chunk instead of silently dropping it.
    if chunk_target:
        sources.append(chunk_source)
        targets.append(chunk_target)
        bonds.append(chunk_bond)
    return sources, targets, bonds


def _save_splits(all_data, output_dir):
    """Split ``all_data`` 80/10/10 into train/validation/test and save each under ``output_dir``."""
    # Hold out 10% as test, then 1/9 of the remaining 90% (10% overall) as validation.
    # NOTE(review): with very few rows the datasets library may reject a split
    # that would leave one side empty — confirm for tiny runs.
    train_val = all_data.train_test_split(test_size=0.1, seed=42)
    final_split = train_val['train'].train_test_split(test_size=(1 / 9), seed=42)
    # Save the post-split train set so validation rows are not also in train.
    final_split['train'].save_to_disk(os.path.join(output_dir, 'train'))
    final_split['test'].save_to_disk(os.path.join(output_dir, 'validation'))
    train_val['test'].save_to_disk(os.path.join(output_dir, 'test'))


def main(args):
    """
    Generate rectified (x0, x1) couplings for every length and save them as a dataset.

    Lengths are visited from ``args.max_length`` down to ``args.min_length``.
    Each dataset row is one chunk of same-length sequences with the columns
    ``source_ids`` (noise x0), ``target_ids`` (generated x1) and ``bond_mask``
    (token-level bond mask). The rows are split 80/10/10 into train/validation/test.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    hparams_args = checkpoint["hyper_parameters"]["args"]
    del checkpoint  # only the hyper-parameters are needed; free the state dict before reloading
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path, args=hparams_args,
        tokenizer=tokenizer, strict=False
    ).to(device).eval()
    pa = PeptideAnalyzer()

    all_sources, all_targets, all_bonds = [], [], []
    for length in range(args.max_length, args.min_length - 1, -1):
        print(f"\n--- Generating for length {length} ---")
        sources, targets, bonds = _collect_length_chunks(model, tokenizer, pa, length, args, device)
        all_sources.extend(sources)
        all_targets.extend(targets)
        all_bonds.extend(bonds)

    if not all_targets:
        print("No valid sequences were generated; nothing to save.")
        return

    all_data = Dataset.from_dict({
        'source_ids': all_sources,
        'target_ids': all_targets,
        'bond_mask': all_bonds
    })
    # Each row is a chunk, so count the sequences inside the chunks.
    num_sequences = sum(len(chunk) for chunk in all_targets)
    print("\n--- Combining all generated data chunks ---")
    print(f"Total valid sequences collected: {num_sequences} (in {len(all_data)} chunks)")
    print(f"Saving new rectified dataset to {args.output_dir}...")
    _save_splits(all_data, args.output_dir)
    print("\nDataset combination and saving complete.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate rectified data couplings using a trained ReDi model for a range of lengths.")
    # --- Required Arguments ---
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the new rectified dataset.")
    # --- Generation Arguments ---
    parser.add_argument("--num_sequences_per_length", type=int, default=100, help="Number of valid sequences to generate for each length.")
    parser.add_argument("--min_length", type=int, default=4, help="Minimum sequence length to generate.")
    parser.add_argument("--max_length", type=int, default=1035, help="Maximum sequence length to generate (and padding length).")
    parser.add_argument("--max_tokens_in_batch", type=int, default=5200, help="Maximum number of tokens in a single generation batch (batch_size * seq_len).")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of steps for the flow matching process.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity. Set to 0 for pure argmax.")
    # --- Environment Arguments ---
    # NOTE(review): these defaults are cluster-specific absolute paths; override them elsewhere.
    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
    args = parser.parse_args()

    # Each length divides --max_tokens_in_batch, so it must be positive;
    # an inverted range would silently generate nothing.
    if args.min_length < 1:
        parser.error("--min_length must be at least 1.")
    if args.min_length > args.max_length:
        parser.error("--min_length must not exceed --max_length.")
    main(args)