Joblib
PeptiVerse / training_classifiers /train_ml_regression.py
ynuozhang
update code
baf3373
import os
import json
import joblib
import optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional
from datasets import load_from_disk, DatasetDict
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.svm import SVR
import xgboost as xgb
from lightning.pytorch import seed_everything
import cupy as cp
from cuml.linear_model import ElasticNet as cuElasticNet
from scipy.stats import spearmanr
# Seed Python/NumPy/torch RNGs once at import time for repeatable runs.
seed_everything(seed=1986)
# -----------------------------
# GPU/CPU helpers
# -----------------------------
def to_gpu(X: np.ndarray):
    """Return ``X`` as a CuPy array.

    A CuPy array is passed through unchanged (dtype untouched); anything else is
    copied to the device as float32.
    """
    return X if isinstance(X, cp.ndarray) else cp.asarray(X, dtype=cp.float32)
def to_cpu(x):
    """Return ``x`` as a host NumPy array, copying it off the GPU when it is a CuPy array."""
    return cp.asnumpy(x) if isinstance(x, cp.ndarray) else np.asarray(x)
# -----------------------------
# Data loading
# -----------------------------
@dataclass
class SplitData:
    """Features, targets and optional sequences for the train and val splits.

    Field order is part of the interface: ``load_split_data`` constructs this
    positionally as ``SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)``.
    Instances are mutable; ``run_optuna_and_refit`` reassigns ``X_train`` and
    ``X_val`` with standardized copies.
    """

    # Training split: stacked embedding features, regression targets, and the
    # "sequence" column values (None when the dataset has no such column).
    X_train: np.ndarray
    y_train: np.ndarray
    seq_train: Optional[np.ndarray]
    # Validation split: same layout as the training fields above.
    X_val: np.ndarray
    y_val: np.ndarray
    seq_val: Optional[np.ndarray]
def _stack_embeddings(col) -> np.ndarray:
    """Convert a column of per-example embedding vectors into a 2-D float32 matrix.

    Args:
        col: Iterable of equal-length embedding vectors (lists or 1-D arrays),
            e.g. a column read from a ``datasets.Dataset``.

    Returns:
        ``np.ndarray`` of shape ``(n_examples, embedding_dim)`` with dtype float32.

    Raises:
        ValueError: If ``col`` is empty, its rows have mismatched lengths, or the
            rows do not form a 2-D matrix (e.g. scalar or per-token 2-D rows).
    """
    try:
        arr = np.asarray(col, dtype=np.float32)
    except ValueError:
        # Ragged or non-array-like input: retry row by row below so the error
        # (if any) comes from np.stack and names the mismatched shapes.
        arr = None
    if arr is None or arr.ndim != 2:
        rows = [np.asarray(row, dtype=np.float32) for row in col]
        if not rows:
            raise ValueError("Cannot build an embedding matrix from an empty column.")
        arr = np.stack(rows)
    if arr.ndim != 2:
        # Every downstream model (StandardScaler, XGBoost, SVR, cuML) needs 2-D input.
        raise ValueError(
            "Expected embeddings to stack into a 2-D (n_examples, dim) matrix, "
            f"got shape {arr.shape}."
        )
    return arr
def load_split_data(dataset_path: str) -> SplitData:
    """Load train/val embeddings, labels and optional sequences from a saved HF dataset.

    Two on-disk layouts are accepted:
      * a ``DatasetDict`` containing ``"train"`` and ``"val"`` splits, or
      * a single dataset with a ``"split"`` column whose values are ``"train"`` / ``"val"``.

    Each split must provide ``embedding`` and ``label`` columns; a ``sequence``
    column is loaded when present.

    Raises:
        ValueError: If neither layout is found, a required column is missing,
            or either split is empty.
    """
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
        if "split" not in ds.column_names:
            raise ValueError("Dataset must be a DatasetDict(train/val) or have a 'split' column.")
        # Feed only the 'split' column to the predicate, in batches, so the
        # (large) embedding column is not decoded row by row just to be filtered.
        train_ds = ds.filter(
            lambda splits: [s == "train" for s in splits],
            input_columns="split",
            batched=True,
        )
        val_ds = ds.filter(
            lambda splits: [s == "val" for s in splits],
            input_columns="split",
            batched=True,
        )

    for split_name, split_ds in (("train", train_ds), ("val", val_ds)):
        for required in ("embedding", "label"):
            if required not in split_ds.column_names:
                raise ValueError(f"Missing column '{required}' in {split_name} split.")
        if len(split_ds) == 0:
            # Fail early with a clear message instead of inside stacking/metrics.
            raise ValueError(f"The {split_name} split is empty.")

    # _stack_embeddings already returns float32, so no extra .astype() copy.
    X_train = _stack_embeddings(train_ds["embedding"])
    X_val = _stack_embeddings(val_ds["embedding"])
    y_train = np.asarray(train_ds["label"], dtype=np.float32)
    y_val = np.asarray(val_ds["label"], dtype=np.float32)

    seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
    seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None

    return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
# -----------------------------
# Metrics
# -----------------------------
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman rank correlation between ``y_true`` and ``y_pred``.

    Returns 0.0 instead of NaN/None when the correlation is undefined
    (e.g. one input is constant), so it is always usable as an objective.
    """
    coef = spearmanr(y_true, y_pred).correlation
    if coef is not None and not np.isnan(coef):
        return float(coef)
    return 0.0
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute regression metrics for one set of predictions.

    Returns:
        Dict of plain Python floats with keys ``rmse``, ``mae``, ``r2`` and
        ``spearman_rho`` (the latter from ``safe_spearmanr``, so 0.0 when the
        rank correlation is undefined).
    """
    try:
        # Dedicated RMSE helper exists only in newer scikit-learn releases.
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        # Older scikit-learn: derive RMSE from MSE. Only the import is guarded,
        # so genuine metric errors (e.g. NaN inputs) are no longer swallowed.
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))

    return {
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
        "spearman_rho": float(safe_spearmanr(y_true, y_pred)),
    }
# -----------------------------
# Model
# -----------------------------
def train_xgb_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    """Train an XGBoost regressor with early stopping on the validation split.

    Args:
        X_train, y_train: Training features and targets.
        X_val, y_val: Validation features and targets, used as the early-stopping
            evaluation set.
        params: XGBoost training parameters plus two required control keys,
            ``num_boost_round`` and ``early_stopping_rounds``. The caller's dict
            is not modified.

    Returns:
        ``(booster, p_train, p_val)``. ``booster`` is truncated to the best
        iteration found by early stopping, and both prediction arrays come from it.

    Raises:
        KeyError: If ``num_boost_round`` or ``early_stopping_rounds`` is missing.
    """
    # Copy first so popping the control keys does not mutate the caller's dict.
    params = dict(params)
    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
    )

    # xgb.train returns the model from the *last* iteration, which can contain
    # up to `early_stopping_rounds` trees past the best one. Slice back to the
    # best iteration so predictions and the saved model match that choice.
    best_iteration = getattr(booster, "best_iteration", None)
    if best_iteration is not None:
        booster = booster[: best_iteration + 1]

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    return booster, p_train, p_val
def train_cuml_elasticnet_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
):
    """Fit a GPU (cuML) ElasticNet on the training split and predict both splits.

    ``params`` must contain ``alpha`` and ``l1_ratio``; ``max_iter`` (5000),
    ``tol`` (1e-4) and ``selection`` ("cyclic") fall back to the shown defaults.
    ``y_val`` is accepted only for signature parity with the other trainers.

    Returns:
        ``(model, p_train, p_val)`` with predictions copied back to NumPy.
    """
    X_train_gpu, X_val_gpu = to_gpu(X_train), to_gpu(X_val)
    y_train_gpu = to_gpu(y_train).astype(cp.float32)

    enet_kwargs = {
        "alpha": float(params["alpha"]),
        "l1_ratio": float(params["l1_ratio"]),
        "fit_intercept": True,
        "max_iter": int(params.get("max_iter", 5000)),
        "tol": float(params.get("tol", 1e-4)),
        "selection": params.get("selection", "cyclic"),
    }
    regressor = cuElasticNet(**enet_kwargs)
    regressor.fit(X_train_gpu, y_train_gpu)

    train_pred = to_cpu(regressor.predict(X_train_gpu))
    val_pred = to_cpu(regressor.predict(X_val_gpu))
    return regressor, train_pred, val_pred
def train_svr_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
):
    """Fit a scikit-learn ``SVR`` on the training split and predict both splits.

    ``params`` must contain ``C``, ``epsilon`` and ``kernel``; ``gamma`` defaults
    to ``"scale"``. ``y_val`` is accepted only for signature parity with the
    other trainers.

    Returns:
        ``(model, p_train, p_val)``.
    """
    svr_kwargs = {
        "C": float(params["C"]),
        "epsilon": float(params["epsilon"]),
        "kernel": params["kernel"],
        "gamma": params.get("gamma", "scale"),
    }
    regressor = SVR(**svr_kwargs)
    regressor.fit(X_train, y_train)

    train_pred = regressor.predict(X_train)
    val_pred = regressor.predict(X_val)
    return regressor, train_pred, val_pred
# -----------------------------
# Saving + plots
# -----------------------------
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    sequences: Optional[np.ndarray] = None,
):
    """Write ``<out_dir>/<split_name>_predictions.csv`` with targets, predictions and residuals.

    Columns are ``y_true``, ``y_pred`` and ``residual`` (``y_true - y_pred``), all
    as float; when ``sequences`` is given it becomes the leading ``sequence``
    column. ``out_dir`` is created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)

    columns = {}
    if sequences is not None:
        columns["sequence"] = sequences
    columns["y_true"] = y_true.astype(float)
    columns["y_pred"] = y_pred.astype(float)
    columns["residual"] = (y_true - y_pred).astype(float)

    csv_path = os.path.join(out_dir, f"{split_name}_predictions.csv")
    pd.DataFrame(columns).to_csv(csv_path, index=False)
def plot_regression_diagnostics(out_dir: str, y_true: np.ndarray, y_pred: np.ndarray):
    """Save three diagnostic PNGs for a set of regression predictions into ``out_dir``.

    * ``pred_vs_true.png``     – scatter of predictions against targets
    * ``residual_hist.png``    – 50-bin histogram of ``y_true - y_pred``
    * ``residual_vs_pred.png`` – scatter of residuals against predictions

    ``out_dir`` is created if missing; each figure is closed after saving.
    """
    os.makedirs(out_dir, exist_ok=True)
    residuals = y_true - y_pred

    # (filename, drawing callback, x label, y label, title) per figure.
    panels = (
        ("pred_vs_true.png",
         lambda: plt.scatter(y_true, y_pred, s=8, alpha=0.5),
         "y_true", "y_pred", "Predicted vs True"),
        ("residual_hist.png",
         lambda: plt.hist(residuals, bins=50),
         "residual (y_true - y_pred)", "count", "Residual Histogram"),
        ("residual_vs_pred.png",
         lambda: plt.scatter(y_pred, residuals, s=8, alpha=0.5),
         "y_pred", "residual", "Residuals vs Prediction"),
    )
    for filename, draw, xlabel, ylabel, title in panels:
        plt.figure()
        draw()
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, filename))
        plt.close()
# -----------------------------
# Optuna objective (OPTIMIZE SPEARMAN RHO)
# -----------------------------
def make_objective(model_name: str, data: SplitData):
    """Build an Optuna objective that tunes ``model_name`` and returns validation Spearman rho.

    Supported ``model_name`` values are ``"xgb_reg"``, ``"enet_gpu"`` and ``"svr"``;
    any other value raises ``ValueError`` when the objective is first called.
    Each trial also records ``spearman_rho``, ``rmse``, ``mae`` and ``r2`` as
    trial user attributes.
    """
    X_tr, y_tr, X_va, y_va = data.X_train, data.y_train, data.X_val, data.y_val

    def _fit_xgb(trial: optuna.Trial):
        # Search space for GPU XGBoost; boosting rounds are tuned as well.
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            "max_depth": trial.suggest_int("max_depth", 2, 16),
            "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 500.0, log=True),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
            "tree_method": "hist",
            "device": "cuda",
        }
        params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 2000)
        params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
        return train_xgb_reg(X_tr, y_tr, X_va, y_va, params.copy())

    def _fit_enet(trial: optuna.Trial):
        params = {
            "alpha": trial.suggest_float("alpha", 1e-8, 10.0, log=True),
            "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
            "max_iter": trial.suggest_int("max_iter", 1000, 20000),
            "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
            "selection": trial.suggest_categorical("selection", ["cyclic", "random"]),
        }
        return train_cuml_elasticnet_reg(X_tr, y_tr, X_va, y_va, params)

    def _fit_svr(trial: optuna.Trial):
        kernel = trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"])
        params = {
            "kernel": kernel,
            "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
            "epsilon": trial.suggest_float("epsilon", 1e-4, 1.0, log=True),
        }
        # gamma is only searched for kernels that use it.
        if kernel in ["rbf", "poly", "sigmoid"]:
            params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
        else:
            params["gamma"] = "scale"
        return train_svr_reg(X_tr, y_tr, X_va, y_va, params)

    fitters = {"xgb_reg": _fit_xgb, "enet_gpu": _fit_enet, "svr": _fit_svr}

    def objective(trial: optuna.Trial) -> float:
        fit = fitters.get(model_name)
        if fit is None:
            raise ValueError(f"Unknown model_name={model_name}")
        _, _, val_pred = fit(trial)

        metrics = eval_regression(y_va, val_pred)
        for key in ("spearman_rho", "rmse", "mae", "r2"):
            trial.set_user_attr(key, metrics[key])
        # Optuna maximizes this value.
        return metrics["spearman_rho"]

    return objective
# -----------------------------
# Main
# -----------------------------
def _refit_best_model(
    model_name: str, best_params: Dict[str, Any], data: SplitData, out_dir: str
) -> Tuple[str, np.ndarray, np.ndarray]:
    """Retrain ``model_name`` with ``best_params``, save it under ``out_dir``, and predict.

    Returns ``(model_path, p_train, p_val)``; raises ``ValueError`` for an unknown model.
    """
    if model_name == "xgb_reg":
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "gamma": best_params["gamma"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "subsample": best_params["subsample"],
            "colsample_bytree": best_params["colsample_bytree"],
            "learning_rate": best_params["learning_rate"],
            "tree_method": "hist",
            "device": "cuda",
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        model, p_tr, p_va = train_xgb_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )
        model_path = os.path.join(out_dir, "best_model.json")
        model.save_model(model_path)
    elif model_name == "enet_gpu":
        model, p_tr, p_va = train_cuml_elasticnet_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
        joblib.dump(model, model_path)
    elif model_name == "svr":
        model, p_tr, p_va = train_svr_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_svr.joblib")
        joblib.dump(model, model_path)
    else:
        raise ValueError(model_name)
    return model_path, p_tr, p_va


def run_optuna_and_refit(
    dataset_path: str,
    out_dir: str,
    model_name: str,
    n_trials: int = 200,
    standardize_X: bool = True,
    seed: Optional[int] = 1986,
):
    """Tune ``model_name`` with Optuna (maximizing val Spearman rho), refit it, and save artifacts.

    Files written to ``out_dir``: ``scaler.joblib`` (when ``standardize_X``),
    ``study_trials.csv``, the refit model, ``train_predictions.csv`` /
    ``val_predictions.csv``, diagnostic plots, and ``optimization_summary.txt``.

    Args:
        dataset_path: Path passed to ``datasets.load_from_disk``.
        out_dir: Output directory (created if missing).
        model_name: One of ``"xgb_reg"``, ``"enet_gpu"``, ``"svr"``.
        n_trials: Number of Optuna trials.
        standardize_X: Fit a StandardScaler on train features and apply it to both splits.
        seed: Seed for Optuna's TPE sampler so the hyperparameter search is
            repeatable; ``None`` gives an unseeded sampler.

    Raises:
        ValueError: For an unknown ``model_name``, before any data is loaded.
    """
    # Validate up front instead of failing inside the first trial after the
    # data has been loaded and the scaler fitted/saved.
    if model_name not in ("xgb_reg", "enet_gpu", "svr"):
        raise ValueError(f"Unknown model_name={model_name}")

    os.makedirs(out_dir, exist_ok=True)
    data = load_split_data(dataset_path)
    print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")

    # Standardize features for every model (SVR/ElasticNet need it; tree models
    # are largely insensitive). Fitted on train only, then applied to val.
    if standardize_X:
        scaler = StandardScaler()
        data.X_train = scaler.fit_transform(data.X_train).astype(np.float32)
        data.X_val = scaler.transform(data.X_val).astype(np.float32)
        joblib.dump(scaler, os.path.join(out_dir, "scaler.joblib"))
        print("[Preprocess] Saved StandardScaler -> scaler.joblib")

    study = optuna.create_study(
        direction="maximize",
        # seed_everything() does not seed Optuna's sampler RNG; seed it explicitly.
        sampler=optuna.samplers.TPESampler(seed=seed),
        # NOTE: the objective never calls trial.report(), so this pruner never prunes.
        pruner=optuna.pruners.MedianPruner(),
    )
    study.optimize(make_objective(model_name, data), n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_rho = float(best.user_attrs.get("spearman_rho", best.value))
    best_rmse = float(best.user_attrs.get("rmse", np.nan))
    best_mae = float(best.user_attrs.get("mae", np.nan))
    best_r2 = float(best.user_attrs.get("r2", np.nan))

    model_path, p_tr, p_va = _refit_best_model(model_name, best_params, data, out_dir)
    # The refit is a fresh training run and need not reproduce the trial's
    # scores exactly, so also report what the saved model achieves.
    refit = eval_regression(data.y_val, p_va)

    save_predictions_csv(out_dir, "train", data.y_train, p_tr, data.seq_train)
    save_predictions_csv(out_dir, "val", data.y_val, p_va, data.seq_val)
    plot_regression_diagnostics(out_dir, data.y_val, p_va)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"Best trial: {best.number}",
        f"Val Spearman rho (objective): {best_rho:.6f}",
        f"Val RMSE: {best_rmse:.6f}",
        f"Val MAE: {best_mae:.6f}",
        f"Val R2: {best_r2:.6f}",
        f"Refit Val Spearman rho: {refit['spearman_rho']:.6f}",
        f"Refit Val RMSE: {refit['rmse']:.6f}",
        f"Refit Val MAE: {refit['mae']:.6f}",
        f"Refit Val R2: {refit['r2']:.6f}",
        f"Model saved to: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))
if __name__ == "__main__":
    # Command-line entry point: parse options and run the tune-then-refit pipeline.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_path", type=str, required=True)
    cli.add_argument("--out_dir", type=str, required=True)
    cli.add_argument("--model", type=str, choices=["xgb_reg", "enet_gpu", "svr"], required=True)
    cli.add_argument("--n_trials", type=int, default=200)
    cli.add_argument("--no_standardize", action="store_true", help="Disable StandardScaler on X")
    cli_args = cli.parse_args()

    run_optuna_and_refit(
        cli_args.dataset_path,
        cli_args.out_dir,
        cli_args.model,
        n_trials=cli_args.n_trials,
        standardize_X=not cli_args.no_standardize,
    )