import os
import json
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import optuna
import xgboost as xgb
import cupy as cp
from datasets import load_from_disk, DatasetDict
from lightning.pytorch import seed_everything
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_recall_curve, roc_curve
)
from sklearn.svm import SVC, LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from cuml.svm import SVC as cuSVC
from cuml.linear_model import LogisticRegression as cuLogReg

# Fix the Python, NumPy, and torch RNG seeds so runs are reproducible.
seed_everything(1986)


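# Helpers to move arrays between host (NumPy) and device (CuPy) memory.
# cuML estimators accept CuPy arrays directly, so each split is converted
# once, keeping the float32 dtype the embeddings are stored in.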
def to_gpu(X: np.ndarray):
    if isinstance(X, cp.ndarray):
        return X
    return cp.asarray(X, dtype=cp.float32)


def to_cpu(x):
    if isinstance(x, cp.ndarray):
        return cp.asnumpy(x)
    return np.asarray(x)


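# Container for one train/val split: embedding matrices, integer labels,
# and (optionally) the raw sequences so predictions can be traced back.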
@dataclass
class SplitData:
    X_train: np.ndarray
    y_train: np.ndarray
    seq_train: Optional[np.ndarray]
    X_val: np.ndarray
    y_val: np.ndarray
    seq_val: Optional[np.ndarray]


def _stack_embeddings(col) -> np.ndarray:
    # Fast path: a uniform list of lists converts directly to a 2-D array.
    arr = np.asarray(col, dtype=np.float32)
    if arr.ndim != 2:
        # Fallback for columns stored as a sequence of 1-D arrays.
        arr = np.stack(col).astype(np.float32)
    return arr


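# Load a dataset saved with `datasets.save_to_disk`. Two layouts are
# supported: a DatasetDict with "train"/"val" splits, or a single Dataset
# carrying a "split" column. Both must provide "embedding" and "label".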
def load_split_data(dataset_path: str) -> SplitData:
    ds = load_from_disk(dataset_path)

    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
        if "split" not in ds.column_names:
            raise ValueError(
                "Dataset must be a DatasetDict(train/val) or have a 'split' column."
            )
        train_ds = ds.filter(lambda x: x["split"] == "train")
        val_ds = ds.filter(lambda x: x["split"] == "val")

    for required in ["embedding", "label"]:
        if required not in train_ds.column_names:
            raise ValueError(f"Missing column '{required}' in train split.")
        if required not in val_ds.column_names:
            raise ValueError(f"Missing column '{required}' in val split.")

    X_train = _stack_embeddings(train_ds["embedding"])
    y_train = np.asarray(train_ds["label"], dtype=np.int64)

    X_val = _stack_embeddings(val_ds["embedding"])
    y_val = np.asarray(val_ds["label"], dtype=np.int64)

    seq_train = None
    seq_val = None
    if "sequence" in train_ds.column_names:
        seq_train = np.asarray(train_ds["sequence"])
    if "sequence" in val_ds.column_names:
        seq_val = np.asarray(val_ds["sequence"])

    return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)


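# precision_recall_curve returns arrays of length n_thresholds + 1; the
# final precision/recall pair corresponds to no positive predictions and
# has no threshold, so F1 is computed over [:-1] to stay aligned with
# `thresholds`.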
def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
    """
    Find the threshold maximizing F1 on the given set.
    Returns (best_threshold, best_f1).
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
    best_idx = int(np.nanargmax(f1s))
    return float(thresholds[best_idx]), float(f1s[best_idx])


def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "f1": float(f1_score(y_true, y_pred)),
        "auc": float(roc_auc_score(y_true, y_prob)),
        "ap": float(average_precision_score(y_true, y_prob)),
        "threshold": float(threshold),
    }


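# XGBoost trainer. `num_boost_round` and `early_stopping_rounds` are popped
# out of the params dict because xgb.train takes them as keyword arguments,
# not as Booster params; callers that reuse the dict should pass a copy.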
def train_xgb(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
    )

    # Predict with the best early-stopped iteration, not the final one.
    best_range = (0, booster.best_iteration + 1)
    p_train = booster.predict(dtrain, iteration_range=best_range)
    p_val = booster.predict(dval, iteration_range=best_range)
    return booster, p_train, p_val


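# GPU kernel SVM via cuML. probability=True enables predict_proba (cuML
# fits an internal probability calibration, similar in spirit to sklearn's),
# which the probability extraction below relies on.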
def train_cuml_svc(X_train, y_train, X_val, y_val, params):
    Xtr = to_gpu(X_train)
    Xva = to_gpu(X_val)
    ytr = to_gpu(y_train).astype(cp.int32)

    clf = cuSVC(
        C=float(params["C"]),
        kernel=params["kernel"],
        gamma=params.get("gamma", "scale"),
        class_weight=params.get("class_weight", None),
        probability=bool(params.get("probability", True)),
        random_state=1986,
        max_iter=int(params.get("max_iter", 1000)),
        tol=float(params.get("tol", 1e-4)),
    )

    clf.fit(Xtr, ytr)

    p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
    p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
    return clf, p_train, p_val


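# GPU elastic-net logistic regression via cuML. The "qn" (quasi-Newton)
# solver is the one cuML's LogisticRegression supports for the elasticnet
# penalty.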
def train_cuml_elastic_net(X_train, y_train, X_val, y_val, params):
    Xtr = to_gpu(X_train)
    Xva = to_gpu(X_val)
    ytr = to_gpu(y_train).astype(cp.int32)

    clf = cuLogReg(
        penalty="elasticnet",
        C=float(params["C"]),
        l1_ratio=float(params["l1_ratio"]),
        class_weight=params.get("class_weight", None),
        max_iter=int(params.get("max_iter", 1000)),
        tol=float(params.get("tol", 1e-4)),
        solver="qn",
        fit_intercept=True,
    )
    clf.fit(Xtr, ytr)

    p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
    p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
    return clf, p_train, p_val


def train_svm(X_train, y_train, X_val, y_val, params):
    """
    Kernel SVM via sklearn's SVC (CPU only).
    probability=True enables predict_proba but slows down fitting.
    """
    clf = SVC(
        C=float(params["C"]),
        kernel=params["kernel"],
        gamma=params.get("gamma", "scale"),
        class_weight=params.get("class_weight", None),
        probability=True,
        random_state=1986,
    )
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val


def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
    """
    Fast linear SVM (LinearSVC) + probability calibration.
    Usually much faster than SVC on large datasets.
    """
    base = LinearSVC(
        C=float(params["C"]),
        class_weight=params.get("class_weight", None),
        max_iter=int(params.get("max_iter", 5000)),
        random_state=1986,
    )

    clf = CalibratedClassifierCV(base, method="sigmoid", cv=3)
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val


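# Write per-example predictions to CSV, thresholded with the tuned cutoff.
# The sequence column is prepended when available so rows are identifiable.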
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


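# Save precision-recall and ROC curves for the validation predictions.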
def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)

    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()


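# Build the Optuna objective for a given model family. Each trial trains on
# the train split, sweeps the decision threshold on the val split, and
# returns the best val F1; AUC/AP and the threshold are stashed as user
# attrs so the best trial can be refit and evaluated consistently.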
def make_objective(model_name: str, data: SplitData, out_dir: str):
    Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val

    def objective(trial: optuna.Trial) -> float:
        if model_name == "xgb":
            params = {
                "objective": "binary:logistic",
                "eval_metric": "logloss",
                "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
                "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
                "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
                "subsample": trial.suggest_float("subsample", 0.5, 1.0),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
                "max_depth": trial.suggest_int("max_depth", 2, 15),
                "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
                "gamma": trial.suggest_float("gamma", 0.0, 10.0),
                "tree_method": "hist",
                "device": "cuda",
            }
            params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
            params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)

            # train_xgb pops entries from the dict, so pass a copy.
            model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())

        elif model_name == "svm":
            svm_kind = trial.suggest_categorical("svm_kind", ["svc", "linear_calibrated"])

            if svm_kind == "svc":
                params = {
                    "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
                    "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
                    "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
                }
                if params["kernel"] in ["rbf", "poly", "sigmoid"]:
                    params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
                else:
                    params["gamma"] = "scale"

                model, p_tr, p_va = train_svm(Xtr, ytr, Xva, yva, params)
            else:
                params = {
                    "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
                    "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
                    "max_iter": trial.suggest_int("max_iter", 2000, 20000),
                }
                model, p_tr, p_va = train_linearsvm_calibrated(Xtr, ytr, Xva, yva, params)

        elif model_name == "svm_gpu":
            params = {
                "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
                "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
                "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
                "probability": True,
                "max_iter": trial.suggest_int("max_iter", 200, 5000),
                "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
            }
            if params["kernel"] in ["rbf", "poly", "sigmoid"]:
                params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
            else:
                params["gamma"] = "scale"

            model, p_tr, p_va = train_cuml_svc(Xtr, ytr, Xva, yva, params)

        elif model_name == "enet_gpu":
            params = {
                "C": trial.suggest_float("C", 1e-4, 1e3, log=True),
                "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
                "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
                "max_iter": trial.suggest_int("max_iter", 200, 5000),
                "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
            }
            model, p_tr, p_va = train_cuml_elastic_net(Xtr, ytr, Xva, yva, params)

        else:
            raise ValueError(f"Unknown model_name={model_name}")

        thr, f1_at_thr = best_f1_threshold(yva, p_va)
        metrics = eval_binary(yva, p_va, thr)
        trial.set_user_attr("threshold", thr)
        trial.set_user_attr("auc", metrics["auc"])
        trial.set_user_attr("ap", metrics["ap"])
        return f1_at_thr

    return objective


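# End-to-end driver: run the Optuna search, refit the winning configuration
# on the train split, and write the model, predictions, curves, and a text
# summary into out_dir.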
def run_optuna_and_refit(
    dataset_path: str,
    out_dir: str,
    model_name: str,
    n_trials: int = 200,
):
    os.makedirs(out_dir, exist_ok=True)

    data = load_split_data(dataset_path)
    print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_thr = float(best.user_attrs["threshold"])
    best_auc = float(best.user_attrs["auc"])
    best_ap = float(best.user_attrs["ap"])
    best_f1 = float(best.value)

    # Refit the best configuration and persist the model.
    if model_name == "xgb":
        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "colsample_bytree": best_params["colsample_bytree"],
            "subsample": best_params["subsample"],
            "learning_rate": best_params["learning_rate"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "gamma": best_params["gamma"],
            "tree_method": "hist",
            "device": "cuda",  # match the device used during tuning
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        model, p_tr, p_va = train_xgb(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )
        model_path = os.path.join(out_dir, "best_model.json")
        model.save_model(model_path)

    elif model_name == "svm":
        svm_kind = best_params["svm_kind"]
        if svm_kind == "svc":
            model, p_tr, p_va = train_svm(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
        else:
            model, p_tr, p_va = train_linearsvm_calibrated(data.X_train, data.y_train, data.X_val, data.y_val, best_params)

        model_path = os.path.join(out_dir, "best_model.joblib")
        joblib.dump(model, model_path)

    elif model_name == "svm_gpu":
        model, p_tr, p_va = train_cuml_svc(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_cuml_svc.joblib")
        joblib.dump(model, model_path)

    elif model_name == "enet_gpu":
        model, p_tr, p_va = train_cuml_elastic_net(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
        joblib.dump(model, model_path)

    else:
        raise ValueError(model_name)

    save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
    save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)

    plot_curves(out_dir, data.y_val, p_va)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"Best trial: {best.number}",
        f"Best F1 (val @ best-threshold): {best_f1:.4f}",
        f"Val AUC: {best_auc:.4f}",
        f"Val AP: {best_ap:.4f}",
        f"Best threshold (picked on val): {best_thr:.4f}",
        f"Model saved to: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    # "svm" is included so the CPU sklearn path implemented above is reachable.
    parser.add_argument("--model", type=str, choices=["xgb", "svm", "svm_gpu", "enet_gpu"], required=True)
    parser.add_argument("--n_trials", type=int, default=200)
    args = parser.parse_args()

    run_optuna_and_refit(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        model_name=args.model,
        n_trials=args.n_trials,
    )
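# Example invocation (script name and paths are illustrative):
#   python tune_classifiers.py --dataset_path data/embeddings_ds \
#       --out_dir runs/xgb --model xgb --n_trials 200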