Train Script Share
#1 opened by BeniaDev
Hello there!
Can you please share your training script or .ipynb? The requirements would also be really helpful.
I'm really interested in how much GPU you used. Are two 40 GB A100s enough?
Thank you for considering my request. I look forward to your positive response.
I use one A40 (48 GB) GPU to fine-tune an 8B model, with lots of black magic — things like CachedMultipleNegativesRankingLoss, paged AdamW, and so on.
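As for requirements: judging from the imports, you only need sentence-transformers, datasets, torch, numpy, and rich, plus flash-attn and bitsandbytes if you want the optional FlashAttention-2 and 8-bit/paged optimizer paths (and a local dataset.py that provides load_train_data).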
Here’s my script; feel free to adapt it to your needs.
"""Utilities for finetuning Qwen embedding models on the code retrieval dataset.
This script uses `sentence-transformers`' CachedMultipleNegativesRankingLoss within the
``SentenceTransformerTrainer`` to learn better alignments between
natural-language queries and code snippets. It mirrors the inference-time setup
provided in `qwen.py` and can be invoked from the command line:
python qwen-finetune.py --model 0.6B --epochs 2 --batch-size 16
The script will automatically split the training CSV into train/validation sets,
log progress with Rich, and save the fine-tuned model to the requested output
directory, with the option to push the result to the Hugging Face Hub.
"""
from __future__ import annotations
import argparse
import random
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import torch
from dataset import load_train_data
from datasets import Dataset
from rich.console import Console
from rich.table import Table
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
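# flash-attn is optional: when it can be imported, the model is loaded with
# FlashAttention-2 in bfloat16 (see load_qwen_model) and bf16 training is enabled below.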
try:
    import flash_attn  # noqa: F401
    flash_attn_installed = True
except ImportError:
    flash_attn_installed = False
console = Console()
@dataclass
class TrainingConfig:
"""Serializable training configuration for reproducibility."""
model_variant: str
train_file: str
output_dir: str
run_name: str
epochs: int
batch_size: int
learning_rate: float
warmup_ratio: float
validation_split: float
eval_steps: int
checkpoint_steps: int
seed: int
mixed_precision: bool
push_to_hub: bool
hub_model_id: Optional[str]
hub_token: Optional[str]
hub_private_repo: bool
save_total_limit: Optional[int]
logging_steps: Optional[int]
gradient_accumulation_steps: int = 1
gradient_checkpointing: bool = False
optim: str = "adamw_torch"
def parse_arguments() -> TrainingConfig:
parser = argparse.ArgumentParser(
description="Finetune Qwen embedding models for code retrieval.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
epilog="""
Examples:
python qwen-finetune.py --model 0.6B --epochs 1 ;
python qwen-finetune.py --model 4B --batch-size 16 --learning-rate 2e-5 ;
python qwen-finetune.py --validation-split 0.1 --eval-steps 200 ;
python qwen-finetune.py --push-to-hub --hub-model-id username/qwen-0.6b-code
""",
)
parser.add_argument(
"--model",
choices=("0.6B", "4B", "8B"),
default="0.6B",
help="Base Qwen embedding model variant to finetune.",
)
parser.add_argument(
"--train-file",
default="data/train_queries.csv",
help="CSV file containing 'query' and 'code' columns.",
)
parser.add_argument(
"--output-dir",
default="models/qwen-finetuned",
help="Directory where the fine-tuned model checkpoints will be stored.",
)
parser.add_argument(
"--run-name",
default=None,
help="Optional run name. Defaults to a timestamped folder inside output-dir.",
)
parser.add_argument(
"--epochs", type=int, default=1, help="Number of training epochs."
)
parser.add_argument(
"--batch-size", type=int, default=8, help="Batch size for training dataloader."
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
default=1,
help="Number of gradient accumulation steps to simulate larger batch sizes.",
)
parser.add_argument(
"--learning-rate",
type=float,
default=5e-5,
help="Learning rate passed to the AdamW optimizer.",
)
parser.add_argument(
"--warmup-ratio",
type=float,
default=0.1,
help="Portion of total steps used for LR warmup.",
)
parser.add_argument(
"--validation-split",
type=float,
default=0.1,
help="Fraction of data reserved for validation (0 disables validation).",
)
parser.add_argument(
"--eval-steps",
type=int,
default=0,
        help=(
            "How frequently to run validation (in training steps). Set to 0 to evaluate once per epoch."
        ),
)
parser.add_argument(
"--checkpoint-steps",
type=int,
default=0,
        help="How frequently to save intermediate checkpoints (in steps). 0 saves per epoch when validation is enabled, otherwise disables checkpointing.",
)
parser.add_argument(
"--save-total-limit",
type=int,
default=None,
help="Maximum number of checkpoints to keep on disk (None disables cleanup).",
)
parser.add_argument(
"--seed",
type=int,
default=2025,
help="Random seed for shuffling and weight initialisation.",
)
parser.add_argument(
"--mixed-precision",
action="store_true",
help="Enable mixed precision (fp16/amp) training if supported.",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the fine-tuned model to the Hugging Face Hub after training.",
)
parser.add_argument(
"--hub-model-id",
default=None,
help="Target repository name on the Hugging Face Hub (e.g. username/model-name).",
)
parser.add_argument(
"--hub-token",
default=None,
help="Optional Hugging Face access token used for hub operations.",
)
parser.add_argument(
"--hub-private-repo",
action="store_true",
help="Create the Hub repository as private when pushing for the first time.",
)
parser.add_argument(
"--logging-steps",
type=int,
default=None,
help="How frequently to log training metrics (None keeps the library default).",
)
parser.add_argument(
"--gradient-checkpointing",
action="store_true",
help="Enable gradient checkpointing to trade extra compute for lower activation memory.",
)
parser.add_argument(
"--optim",
choices=(
"adamw_torch",
"adamw_torch_fused",
"adamw_apex_fused",
"adamw_hf",
"adamw_anyprecision",
"adamw_bnb_8bit",
"paged_adamw_8bit",
"paged_adamw_32bit",
),
default="adamw_torch",
help="Optimizer to use during training. 8-bit and paged variants require bitsandbytes.",
)
args = parser.parse_args()
run_name = args.run_name or datetime.now().strftime("%Y%m%d-%H%M%S")
return TrainingConfig(
model_variant=args.model,
train_file=args.train_file,
output_dir=args.output_dir,
run_name=run_name,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
validation_split=args.validation_split,
eval_steps=args.eval_steps,
checkpoint_steps=args.checkpoint_steps,
save_total_limit=args.save_total_limit,
seed=args.seed,
mixed_precision=args.mixed_precision,
push_to_hub=args.push_to_hub,
hub_model_id=args.hub_model_id,
hub_token=args.hub_token,
hub_private_repo=args.hub_private_repo,
logging_steps=args.logging_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_checkpointing=args.gradient_checkpointing,
optim=args.optim,
)
def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
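# Qwen3 embedding models use last-token pooling, hence the left-padded tokenizer below;
# FlashAttention-2, bfloat16 weights and device_map="auto" are only requested when
# flash-attn is installed, otherwise the default loading path is used.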
def load_qwen_model(model_variant: str) -> SentenceTransformer:
model_map = {
"0.6B": "Qwen/Qwen3-Embedding-0.6B",
"4B": "Qwen/Qwen3-Embedding-4B",
"8B": "Qwen/Qwen3-Embedding-8B",
}
if model_variant not in model_map:
raise ValueError(
f"Unknown Qwen model variant '{model_variant}'. Choose from {list(model_map)}."
)
model_name = model_map[model_variant]
console.print(f"[bold green]Loading base model:[/bold green] {model_name}")
return (
SentenceTransformer(
model_name,
model_kwargs={
"attn_implementation": "flash_attention_2",
"dtype": torch.bfloat16,
"device_map": "auto",
},
tokenizer_kwargs={"padding_side": "left"},
)
if flash_attn_installed
else SentenceTransformer(model_name)
)
def split_train_validation(
df, validation_split: float, seed: int
) -> Tuple[torch.utils.data.Dataset, Optional[torch.utils.data.Dataset]]:
if validation_split <= 0:
return df, None
val_size = int(len(df) * validation_split)
if val_size == 0:
console.print(
"[yellow]Validation split too small for the given dataset; proceeding without validation.[/yellow]"
)
return df, None
df_shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
val_df = df_shuffled.iloc[:val_size].reset_index(drop=True)
train_df = df_shuffled.iloc[val_size:].reset_index(drop=True)
return train_df, val_df
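# Each validation row becomes its own (query, code) relevance judgment, so accuracy/MRR/
# nDCG@k measure how highly the paired snippet ranks against the whole validation corpus;
# batch_size=1 keeps the evaluation footprint small at the cost of speed.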
def build_ir_evaluator(df) -> InformationRetrievalEvaluator:
queries: Dict[str, str] = {}
corpus: Dict[str, str] = {}
relevant_docs: Dict[str, Dict[str, int]] = {}
for idx, row in enumerate(df.itertuples(index=False)):
query_id = f"q{idx}"
code_id = f"c{idx}"
queries[query_id] = str(getattr(row, "query", ""))
corpus[code_id] = str(getattr(row, "code", ""))
relevant_docs[query_id] = {code_id: 1}
return InformationRetrievalEvaluator(
queries=queries,
corpus=corpus,
relevant_docs=relevant_docs,
mrr_at_k=[1, 5, 10],
ndcg_at_k=[1, 5, 10],
accuracy_at_k=[1, 5, 10],
show_progress_bar=True,
batch_size=1,
)
def log_dataset_summary(train_df, val_df) -> None:
table = Table(title="Dataset Summary")
table.add_column("Split", justify="left", style="cyan", no_wrap=True)
table.add_column("Samples", justify="right", style="magenta")
table.add_row("Train", f"{len(train_df)}")
if val_df is not None:
table.add_row("Validation", f"{len(val_df)}")
else:
table.add_row("Validation", "0 (disabled)")
console.print(table)
def save_config(config: TrainingConfig, output_dir: Path) -> None:
    import json

    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir / "training_config.json"
    with config_path.open("w", encoding="utf-8") as fp:
        json.dump(asdict(config), fp, indent=2)
console.print(
f"[bold green]Saved training configuration to:[/bold green] {config_path}"
)
def main() -> None:
config = parse_arguments()
console.print("[bold magenta]Starting Qwen embedding finetuning[/bold magenta]")
console.print(f"• Model variant: [bold]{config.model_variant}[/bold]")
console.print(f"• Training file: [bold]{config.train_file}[/bold]")
set_seed(config.seed)
train_df = load_train_data(config.train_file)
train_df = train_df.dropna(subset=["query", "code"]).reset_index(drop=True)
if train_df.empty:
raise RuntimeError("Training dataset is empty after dropping NA rows.")
train_df, val_df = split_train_validation(
train_df, config.validation_split, config.seed
)
log_dataset_summary(train_df, val_df)
model = load_qwen_model(config.model_variant)
if config.gradient_checkpointing:
if hasattr(model, "gradient_checkpointing_enable"):
console.print(
"[bold cyan]Gradient checkpointing enabled for the encoder.[/bold cyan]"
)
model.gradient_checkpointing_enable()
else:
console.print(
"[yellow]Gradient checkpointing requested but unsupported by the current model.[/yellow]"
)
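    # CachedMultipleNegativesRankingLoss (GradCache) is the main memory trick here: the whole
    # per-device batch still acts as in-batch negatives, but embeddings are computed in
    # mini-batches (of size 1 here), so activation memory stays small even for the 8B model.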
train_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)
base_output_dir = Path(config.output_dir)
run_output_dir = base_output_dir / config.run_name
run_output_dir.mkdir(parents=True, exist_ok=True)
save_config(config, run_output_dir)
evaluator = None
if val_df is not None and not val_df.empty:
console.print(
"[bold cyan]Validation enabled:[/bold cyan] InformationRetrievalEvaluator"
)
evaluator = build_ir_evaluator(val_df)
train_dataset = Dataset.from_pandas(
train_df[["query", "code"]], preserve_index=False
)
eval_dataset = None
if val_df is not None and not val_df.empty:
eval_dataset = Dataset.from_pandas(
val_df[["query", "code"]], preserve_index=False
)
eval_strategy = "no"
training_args_kwargs = {}
if evaluator is not None:
if config.eval_steps > 0:
eval_strategy = "steps"
training_args_kwargs["eval_steps"] = config.eval_steps
else:
eval_strategy = "epoch"
if config.checkpoint_steps > 0:
save_strategy = "steps"
training_args_kwargs["save_steps"] = config.checkpoint_steps
else:
save_strategy = "epoch" if evaluator is not None else "no"
if config.logging_steps is not None:
training_args_kwargs["logging_steps"] = config.logging_steps
if config.save_total_limit is not None:
training_args_kwargs["save_total_limit"] = config.save_total_limit
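    # The 8-bit and paged AdamW variants keep optimizer state in 8-bit and/or paged CUDA
    # memory via bitsandbytes, the other big VRAM saver when fine-tuning the 8B variant
    # on a single GPU.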
bitsandbytes_optimizers = {
"adamw_bnb_8bit",
"paged_adamw_8bit",
"paged_adamw_32bit",
}
if config.optim in bitsandbytes_optimizers:
try:
import bitsandbytes as _bnb # noqa: F401
except ImportError as exc:
raise RuntimeError(
f"Optimizer '{config.optim}' requires the bitsandbytes package. Install it with `pip install bitsandbytes`."
) from exc
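    # BatchSamplers.NO_DUPLICATES keeps duplicate pairs out of a batch, which would otherwise
    # act as false negatives for the ranking loss; bf16 is tied to the flash-attn path because
    # the model weights are loaded in bfloat16 in that case.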
training_args = SentenceTransformerTrainingArguments(
output_dir=str(run_output_dir),
run_name=config.run_name,
num_train_epochs=config.epochs,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
learning_rate=config.learning_rate,
warmup_ratio=config.warmup_ratio,
eval_strategy=eval_strategy,
        eval_on_start=eval_strategy != "no",  # only run the pre-training eval when validation is configured
save_strategy=save_strategy,
seed=config.seed,
        fp16=config.mixed_precision and not flash_attn_installed,  # fp16 and bf16 cannot both be set
        bf16=flash_attn_installed,
gradient_checkpointing=config.gradient_checkpointing,
optim=config.optim,
batch_sampler=BatchSamplers.NO_DUPLICATES,
push_to_hub=config.push_to_hub,
hub_model_id=config.hub_model_id,
hub_token=config.hub_token,
hub_private_repo=config.hub_private_repo,
report_to=["tensorboard"],
load_best_model_at_end=evaluator is not None
and save_strategy != "no"
and eval_strategy != "no",
**training_args_kwargs,
)
trainer = SentenceTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=evaluator,
)
console.print("[bold green]Commencing training...[/bold green]")
trainer.train()
console.print("[bold green]Saving final model checkpoint...[/bold green]")
trainer.save_model()
if evaluator is not None:
console.print(
"[bold cyan]Running final evaluation on validation set...[/bold cyan]"
)
scores = evaluator(trainer.model, output_path=str(run_output_dir))
for metric, score in scores.items():
console.print(f" • {metric}: [magenta]{score:.4f}[/magenta]")
if config.push_to_hub:
console.print("[bold blue]Pushing model to the Hugging Face Hub...[/bold blue]")
trainer.push_to_hub()
console.print(
f"[bold green]Training complete.[/bold green] Fine-tuned model saved to [bold]{run_output_dir}[/bold]"
)
if __name__ == "__main__":
main()
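Once training has finished, loading the result back for retrieval is just pointing SentenceTransformer at the run directory. A minimal sketch (the path is a placeholder and the "query" prompt name follows the Qwen3-Embedding convention; adjust both to your setup):

# Load the fine-tuned checkpoint and rank code snippets for a query.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("models/qwen-finetuned/<run-name>")  # placeholder run directory
query_embeddings = model.encode(["read a csv file into a dataframe"], prompt_name="query")
code_embeddings = model.encode(["import pandas as pd\ndf = pd.read_csv('data.csv')"])
print(model.similarity(query_embeddings, code_embeddings))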
Oh, that's great! Thank you very much.