import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
import optuna
import os
import json
from typing import Optional
import matplotlib.pyplot as plt
import pandas as pd
from lightning.pytorch import seed_everything

seed_everything(1986)


def infer_in_dim_from_unpooled_ds(ds) -> int:
    """Infer the embedding dimension H from the first example's (L, H) embedding."""
    ex = ds[0]
    return int(len(ex["embedding"][0]))


def load_split(dataset_path):
    """Load a saved DatasetDict and return its 'train' and 'val' splits."""
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")


def collate_unpooled(batch):
    """Pad per-residue embeddings to the batch max length and build a boolean mask."""
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])  # embedding dim, e.g. 1280
    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)  # (L, H)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True
    return X, M, y


# ======================== Helper functions =========================================
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)
    # PR
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()
    # ROC
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()


# ======================== Shared OPTUNA training scheme =========================================
def best_f1_threshold(y_true, y_prob):
    """Return the decision threshold maximizing F1 on the PR curve, and that F1."""
    p, r, thr = precision_recall_curve(y_true, y_prob)
    # thr has one fewer entry than p/r, so drop the last (recall=0) point
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])


@torch.no_grad()
def eval_probs(model, loader, device):
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        logits = model(X, M)
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        ys.append(y.numpy())
        ps.append(prob)
    return np.concatenate(ys), np.concatenate(ps)


def train_one_epoch(model, loader, optim, criterion, device):
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        logits = model(X, M)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()
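# --- Hedged example -----------------------------------------------------------
# Minimal smoke test (not called anywhere in the pipeline; the function name
# `_collate_smoke_test` and the toy shapes are illustrative assumptions) showing
# what `collate_unpooled` produces: embeddings padded to (B, Lmax, H) plus a
# boolean mask marking the real (non-padded) positions.
def _collate_smoke_test():
    batch = [
        {"embedding": np.random.randn(5, 8).tolist(), "length": 5, "label": 1},
        {"embedding": np.random.randn(3, 8).tolist(), "length": 3, "label": 0},
    ]
    X, M, y = collate_unpooled(batch)
    assert X.shape == (2, 5, 8) and M.shape == (2, 5) and y.shape == (2,)
    assert M[1, 3:].sum() == 0  # padded tail of the shorter sequence is masked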
# ======================== MLP =========================================
# Still needs mean pooling over sequence length
class MaskedMeanPool(nn.Module):
    def forward(self, X, M):
        # X: (B, L, H), M: (B, L)
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom  # (B, H)


class MLPClassifier(nn.Module):
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)  # logits


# ======================== CNN =========================================
# Treat the 1280 embedding dimensions as channels
class CNNClassifier(nn.Module):
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # X: (B, L, H) -> (B, H, L) for Conv1d
        Xc = X.transpose(1, 2)
        Y = self.conv(Xc).transpose(1, 2)  # (B, L, C)
        # masked mean pool over L
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Y * Mf).sum(dim=1) / denom  # (B, C)
        return self.head(pooled).squeeze(-1)


# ========================== Transformer ====================================
class TransformerClassifier(nn.Module):
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask: True = pad positions
        pad_mask = ~M
        Z = self.proj(X)  # (B, L, d)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)  # (B, L, d)
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Z * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)
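# --- Hedged example -----------------------------------------------------------
# Minimal shape check (illustrative only; `_model_shape_check` and the toy sizes
# are assumptions, not part of the pipeline): each classifier maps padded
# (B, L, H) embeddings plus a (B, L) mask to per-sequence logits of shape (B,).
def _model_shape_check():
    B, L, H = 2, 7, 16
    X = torch.randn(B, L, H)
    M = torch.ones(B, L, dtype=torch.bool)
    M[1, 5:] = False  # second sequence is shorter; tail positions are padding
    for m in (
        MLPClassifier(in_dim=H, hidden=32, dropout=0.0),
        CNNClassifier(in_ch=H, c=8, k=3, layers=1, dropout=0.0),
        TransformerClassifier(in_dim=H, d_model=16, nhead=4, layers=1, ff=32, dropout=0.0),
    ):
        assert m(X, M).shape == (B,)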
# ========================== OPTUNA ====================================
def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
    # shared hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    if model_name == "mlp":
        hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
        model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
    elif model_name == "cnn":
        c = trial.suggest_categorical("channels", [128, 256, 512])
        k = trial.suggest_categorical("kernel", [3, 5, 7])
        layers = trial.suggest_int("layers", 1, 4)
        model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
    elif model_name == "transformer":
        d = trial.suggest_categorical("d_model", [128, 256, 384])
        nhead = trial.suggest_categorical("nhead", [4, 8])
        layers = trial.suggest_int("layers", 1, 4)
        ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
        model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead,
                                      layers=layers, ff=ff, dropout=dropout)
    else:
        raise ValueError(model_name)
    model = model.to(device)

    # class-imbalance handling: upweight positives by neg/pos
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_f1 = -1.0
    patience = 8
    bad = 0
    for epoch in range(1, 51):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        # prune on the running F1
        trial.report(f1, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            bad = 0
            # record metrics at the best epoch (not the last one, which may be worse)
            trial.set_user_attr("val_auc", float(auc))
            trial.set_user_attr("val_f1", float(f1))
            trial.set_user_attr("val_thr", float(thr))
        else:
            bad += 1
            if bad >= patience:
                break
    return best_f1
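# --- Hedged example -----------------------------------------------------------
# Tiny numeric illustration (toy counts, not from the dataset) of the
# `pos_weight` computed in `objective_nn`: with 900 negatives and 100 positives,
# pos_weight = neg / pos = 9.0, so BCEWithLogitsLoss weighs each positive
# example 9x, roughly balancing the classes' total contribution to the loss.
def _pos_weight_example() -> float:
    ytr = np.array([0] * 900 + [1] * 100)
    pos = ytr.sum()
    neg = len(ytr) - pos
    return float(neg / max(pos, 1))  # -> 9.0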
def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str,
                            n_trials: int = 50, device="cuda:0"):
    os.makedirs(out_dir, exist_ok=True)
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device),
                   n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_f1_optuna = float(best.value)
    best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
    best_thr = float(best.user_attrs.get("val_thr", 0.5))

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    # --- Refit best model ---
    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    # rebuild the best architecture
    dropout = float(best_params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPClassifier(
            in_dim=in_dim,
            hidden=int(best_params["hidden"]),
            dropout=dropout,
        )
    elif model_name == "cnn":
        model = CNNClassifier(
            in_ch=in_dim,
            c=int(best_params["channels"]),
            k=int(best_params["kernel"]),
            layers=int(best_params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        model = TransformerClassifier(
            in_dim=in_dim,
            d_model=int(best_params["d_model"]),
            nhead=int(best_params["nhead"]),
            layers=int(best_params["layers"]),
            ff=int(best_params["ff"]),
            dropout=dropout,
        )
    else:
        raise ValueError(model_name)
    model = model.to(device)

    # loss + optimizer
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # train longer, with early stopping on validation F1
    best_f1_seen, bad, patience = -1.0, 0, 12
    best_state = None
    best_thr_seen = 0.5
    best_auc_seen = -1.0
    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)
        if f1 > best_f1_seen + 1e-4:
            best_f1_seen = f1
            best_thr_seen = thr
            best_auc_seen = auc
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # final predictions; threshold picked on val
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)

    # save model
    model_path = os.path.join(out_dir, "best_model.pt")
    torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)

    # train predictions
    y_true_tr, y_prob_tr = eval_probs(
        model,
        DataLoader(train_ds, batch_size=64, shuffle=False,
                   collate_fn=collate_unpooled, num_workers=4, pin_memory=True),
        device,
    )
    save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
                         sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
    save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
                         sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)
    plot_curves(out_dir, y_true_val, y_prob_val)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        # Optuna results (objective = F1)
        f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
        f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
        f"Best Optuna threshold (val): {best_thr:.4f}",
        # Refit results
        f"Refit best AUC (val): {best_auc_seen:.4f}",
        f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
        "Best params:",
        json.dumps(best_params, indent=2),
        f"Saved model: {model_path}",
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--n_trials", type=int, default=50)
    args = parser.parse_args()

    # argparse already restricts --model via `choices`, so no extra guard is needed;
    # fall back to CPU when no GPU is available
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device=device)
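# --- Hedged usage example -----------------------------------------------------
# Example invocation (file name and paths are placeholders, not from this repo;
# the dataset must be a DatasetDict saved with `save_to_disk`, containing
# 'train' and 'val' splits with 'embedding', 'length', and 'label' columns):
#
#   python train_nn.py \
#       --dataset_path data/embeddings_unpooled \
#       --out_dir runs/mlp \
#       --model mlp \
#       --n_trials 50
#
# Reloading the saved checkpoint later (sketch; assumes the same model class
# and that in_dim matches the embeddings used for training, e.g. 1280):
#
#   ckpt = torch.load("runs/mlp/best_model.pt", map_location="cpu")
#   model = MLPClassifier(in_dim=1280, hidden=int(ckpt["best_params"]["hidden"]),
#                         dropout=float(ckpt["best_params"]["dropout"]))
#   model.load_state_dict(ckpt["state_dict"])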