import json
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import optuna
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything

seed_everything(1986)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training-loop settings shared by the Optuna search and the final refit.
_NUM_WORKERS = 4
_MIN_DELTA = 1e-6  # smallest validation-rho gain that counts as an improvement
_SEARCH_MAX_EPOCHS = 60
_SEARCH_PATIENCE = 10
_REFIT_MAX_EPOCHS = 200
_REFIT_PATIENCE = 20


def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return Spearman's rho of the two arrays, or 0.0 if SciPy yields None/NaN."""
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


# -----------------------------
# Affinity class thresholds (final spec)
# High >= 9 ; Moderate 7-9 ; Low < 7
# 0=High, 1=Moderate, 2=Low
# -----------------------------
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
    """Map affinity values to class indices (0=High, 1=Moderate, 2=Low).

    Values >= 9.0 become 0 and values < 7.0 become 2. Everything else is 1;
    that includes NaN, because both comparisons are False for NaN.

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``y``.
    """
    high = y >= 9.0
    low = y < 7.0
    mid = ~(high | low)
    cls = torch.zeros_like(y, dtype=torch.long)
    cls[mid] = 1
    cls[low] = 2
    return cls


# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
    """Load a ``DatasetDict`` from disk and return its ``(train, val)`` splits.

    Raises:
        ValueError: if the object at ``path`` is not a ``DatasetDict`` or it
            lacks a ``"train"`` or ``"val"`` split.
    """
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


# -----------------------------
# Collate: pooled paired
# -----------------------------
def collate_pair_pooled(batch):
    """Stack pooled target/binder embeddings and labels into float32 tensors.

    Each item must provide ``"target_embedding"``, ``"binder_embedding"`` and
    ``"label"``.

    Returns:
        ``(Pt, Pb, y)`` where ``Pt`` is (B, Ht), ``Pb`` is (B, Hb), ``y`` is (B,).
    """
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)  # (B,Ht)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)  # (B,Hb)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y


# -----------------------------
# Collate: unpooled paired
# -----------------------------
def collate_pair_unpooled(batch):
    """Zero-pad per-token embeddings to the batch maximum and build bool masks.

    The padded lengths come from the ``"target_length"`` / ``"binder_length"``
    fields. An embedding with more rows than the batch maximum of that field
    fails on the slice assignment below, so the length fields must not
    understate the embedding lengths.

    Returns:
        ``(Pt, Mt, Pb, Mb, y)`` where ``Pt`` is (B, Lt_max, Ht), ``Mt`` is
        (B, Lt_max) bool, ``Pb`` is (B, Lb_max, Hb), ``Mb`` is (B, Lb_max) bool
        and ``y`` is (B,) float32. Mask positions beyond a sample's own
        embedding length stay False.
    """
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        # Masks are truncated to the embedding length before being copied in.
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y


# -----------------------------
# Cross-attention models
# -----------------------------
def _make_cross_layer(hidden: int, n_heads: int, dropout: float, batch_first: bool) -> nn.ModuleDict:
    """Build one bidirectional cross-attention layer shared by both models.

    Submodule names and construction order are kept fixed so that
    ``state_dict`` keys and seeded parameter initialisation stay stable.
    """

    def feed_forward() -> nn.Sequential:
        return nn.Sequential(
            nn.Linear(hidden, 4 * hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * hidden, hidden),
        )

    return nn.ModuleDict({
        "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=batch_first),
        "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=batch_first),
        "n1t": nn.LayerNorm(hidden),
        "n2t": nn.LayerNorm(hidden),
        "n1b": nn.LayerNorm(hidden),
        "n2b": nn.LayerNorm(hidden),
        "fft": feed_forward(),
        "ffb": feed_forward(),
    })


class CrossAttnPooled(nn.Module):
    """
    pooled vectors -> treat as single-token sequences for cross attention

    ``forward`` returns ``(regression, class_logits)`` with shapes (B,) and (B, 3).
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList(
            [_make_cross_layer(hidden, n_heads, dropout, batch_first=False) for _ in range(n_layers)]
        )

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        """Run alternating cross attention on (B,Ht) and (B,Hb) pooled vectors."""
        t = self.t_proj(t_vec).unsqueeze(0)  # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1,B,H)

        # LayerNorm and the feed-forward blocks act on the last dimension only,
        # so they are applied directly on the (1,B,H) layout.
        for layer in self.layers:
            t_attn, _ = layer["attn_tb"](t, b, b)
            t = layer["n1t"](t + t_attn)
            t = layer["n2t"](t + layer["fft"](t))

            b_attn, _ = layer["attn_bt"](b, t, t)
            b = layer["n1b"](b + b_attn)
            b = layer["n2b"](b + layer["ffb"](b))

        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


class CrossAttnUnpooled(nn.Module):
    """
    token sequences with masks; alternating cross attention.

    ``forward`` returns ``(regression, class_logits)`` with shapes (B,) and (B, 3).
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList(
            [_make_cross_layer(hidden, n_heads, dropout, batch_first=True) for _ in range(n_layers)]
        )

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        """Average ``X`` (B,L,H) over positions where ``M`` (B,L) is True.

        The denominator is clamped to at least 1, so an all-False row yields
        zeros instead of dividing by zero.
        """
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        """Cross-attend target and binder token sequences and pool with masks.

        T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
        """
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        kp_t = ~Mt  # key_padding_mask True = pad
        kp_b = ~Mb
        # NOTE(review): a sample whose mask is entirely False would have every
        # key masked out; confirm the data never contains such rows.

        for layer in self.layers:
            # T attends to B
            T_attn, _ = layer["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = layer["n1t"](T + T_attn)
            T = layer["n2t"](T + layer["fft"](T))

            # B attends to T
            B_attn, _ = layer["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = layer["n1b"](Bx + B_attn)
            Bx = layer["n2b"](Bx + layer["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


# -----------------------------
# Train/eval
# -----------------------------
def _move_to_device(tensors):
    """Return the given tensors moved to DEVICE (non-blocking), as a list."""
    return [x.to(DEVICE, non_blocking=True) for x in tensors]


@torch.no_grad()
def _eval_spearman(model, loader):
    """Spearman rho between labels and regression outputs over ``loader``.

    Every batch is ``(*model_inputs, y)``; the labels stay on the CPU.
    """
    model.eval()
    ys, ps = [], []
    for *inputs, y in loader:
        pred, _ = model(*_move_to_device(inputs))
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))


def _train_one_epoch(model, loader, opt, loss_reg, loss_cls, cls_w, clip):
    """Run one optimisation pass over ``loader``.

    Every batch is ``(*model_inputs, y)``. The loss is
    ``loss_reg(pred, y) + cls_w * loss_cls(logits, classes)``, with classes
    derived from ``y`` by ``affinity_to_class_tensor``. Gradients are clipped
    to norm ``clip`` unless ``clip`` is None.
    """
    model.train()
    for batch in loader:
        *inputs, y = _move_to_device(batch)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(*inputs)
        loss = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()


def eval_spearman_pooled(model, loader):
    """Validation Spearman rho for a pooled model; batches are ``(t, b, y)``."""
    return _eval_spearman(model, loader)


def eval_spearman_unpooled(model, loader):
    """Validation Spearman rho for an unpooled model; batches are ``(T, Mt, B, Mb, y)``."""
    return _eval_spearman(model, loader)


def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """Train a pooled model for one epoch; batches are ``(t, b, y)``."""
    _train_one_epoch(model, loader, opt, loss_reg, loss_cls, cls_w, clip)


def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """Train an unpooled model for one epoch; batches are ``(T, Mt, B, Mb, y)``."""
    _train_one_epoch(model, loader, opt, loss_reg, loss_cls, cls_w, clip)


def _build_model_and_loaders(mode, train_ds, val_ds, *, hidden, n_heads, n_layers, dropout, batch_size):
    """Create the model, data loaders and train/eval functions for ``mode``.

    ``"pooled"`` selects the pooled pipeline; any other value selects the
    unpooled one. Embedding widths are inferred from the first training row.

    Returns:
        ``(model, train_loader, val_loader, eval_fn, train_fn)``.
    """
    first = train_ds[0]
    if mode == "pooled":
        Ht = len(first["target_embedding"])
        Hb = len(first["binder_embedding"])
        model_cls, collate = CrossAttnPooled, collate_pair_pooled
        eval_fn, train_fn = eval_spearman_pooled, train_one_epoch_pooled
    else:
        Ht = len(first["target_embedding"][0])
        Hb = len(first["binder_embedding"][0])
        model_cls, collate = CrossAttnUnpooled, collate_pair_unpooled
        eval_fn, train_fn = eval_spearman_unpooled, train_one_epoch_unpooled

    model = model_cls(
        Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout
    ).to(DEVICE)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=_NUM_WORKERS,
        # Pinning only helps host-to-GPU copies; on CPU it just emits a warning.
        pin_memory=DEVICE.type == "cuda",
        collate_fn=collate,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    return model, train_loader, val_loader, eval_fn, train_fn


# -----------------------------
# Optuna objective
# -----------------------------
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    """Optuna objective: train one sampled configuration, return best val rho.

    Trains for up to 60 epochs, reporting the validation Spearman rho to the
    trial after each one. Stops early after 10 epochs without improvement.

    Raises:
        optuna.TrialPruned: when the trial's pruner asks to stop.
    """
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    model, train_loader, val_loader, eval_fn, train_fn = _build_model_and_loaders(
        mode, train_ds, val_ds,
        hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout, batch_size=batch,
    )

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best = -1e9
    bad = 0
    for ep in range(1, _SEARCH_MAX_EPOCHS + 1):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        trial.report(rho, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if rho > best + _MIN_DELTA:
            best = rho
            bad = 0
        else:
            bad += 1
            if bad >= _SEARCH_PATIENCE:
                break

    return float(best)


# -----------------------------
# Run: optuna + refit best
# -----------------------------
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
    """Search hyper-parameters with Optuna, refit the best trial, save results.

    Writes three files into ``out_dir`` (created if missing):
    ``optuna_trials.csv``, ``best_model.pt`` and ``best_params.json``.
    The refit trains for up to 200 epochs with a patience of 20 and restores
    the weights from the epoch with the best validation rho.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
    study.trials_dataframe().to_csv(out_path / "optuna_trials.csv", index=False)

    best_params = dict(study.best_trial.params)

    # refit longer
    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    cls_w = float(best_params["cls_weight"])

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    model, train_loader, val_loader, eval_fn, train_fn = _build_model_and_loaders(
        mode, train_ds, val_ds,
        hidden=int(best_params["hidden_dim"]),
        n_heads=int(best_params["n_heads"]),
        n_layers=int(best_params["n_layers"]),
        dropout=float(best_params["dropout"]),
        batch_size=int(best_params["batch_size"]),
    )

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_rho = -1e9
    bad = 0
    best_state = None

    for ep in range(1, _REFIT_MAX_EPOCHS + 1):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + _MIN_DELTA:
            best_rho = rho
            bad = 0
            # Snapshot on the CPU so the copy does not track later updates.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= _REFIT_PATIENCE:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # save
    # Store CPU tensors so the checkpoint loads on machines without the
    # training device, with no map_location needed.
    cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(
        {"mode": mode, "best_params": best_params, "state_dict": cpu_state},
        out_path / "best_model.pt",
    )

    with open(out_path / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_path} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--n_trials", type=int, default=50)
    args = ap.parse_args()

    run(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        mode=args.mode,
        n_trials=args.n_trials,
    )