Train Script Share

#1
by BeniaDev - opened

Hello there!

Could you please share your training script or .ipynb? The requirements would also be really helpful.

I'm also really interested in how much GPU you used. Are two 40 GB A100s enough?

Thank you for considering my request. I look forward to your positive response.

I use one A40 (48 GB) GPU to fine-tune an 8B model, with lots of black magic — things like CachedMultipleNegativesRankingLoss, paged AdamW, and so on.
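
For a sense of scale, a single-GPU 8B run with that setup looks roughly like the following (the flags map to the argparse options in the script below; the exact batch size and accumulation steps are illustrative, not a recipe):

    python qwen-finetune.py --model 8B --batch-size 4 --gradient-accumulation-steps 8 --gradient-checkpointing --optim paged_adamw_8bit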

Here’s my script; feel free to adapt it to your needs.

"""Utilities for finetuning Qwen embedding models on the code retrieval dataset.

This script uses `sentence-transformers`' CachedMultipleNegativesRankingLoss within the
new ``SentenceTransformerTrainer`` to learn better alignments between
natural-language queries and code snippets. It mirrors the inference-time setup
provided in `qwen.py` and can be invoked from the command line:

        python qwen-finetune.py --model 0.6B --epochs 2 --batch-size 16

The script will automatically split the training CSV into train/validation sets,
log progress with Rich, and save the fine-tuned model to the requested output
directory, with the option to push the result to the Hugging Face Hub.
"""

from __future__ import annotations

import argparse
import json
import random
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from dataset import load_train_data
from datasets import Dataset
from rich.console import Console
from rich.table import Table
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

try:
    import flash_attn  # noqa: F401  # optional: enables the FlashAttention-2 / bf16 path below

    flash_attn_installed = True
except ImportError:
    flash_attn_installed = False


console = Console()


@dataclass
class TrainingConfig:
    """Serializable training configuration for reproducibility."""

    model_variant: str
    train_file: str
    output_dir: str
    run_name: str
    epochs: int
    batch_size: int
    learning_rate: float
    warmup_ratio: float
    validation_split: float
    eval_steps: int
    checkpoint_steps: int
    seed: int
    mixed_precision: bool
    push_to_hub: bool
    hub_model_id: Optional[str]
    hub_token: Optional[str]
    hub_private_repo: bool
    save_total_limit: Optional[int]
    logging_steps: Optional[int]
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = False
    optim: str = "adamw_torch"


def parse_arguments() -> TrainingConfig:
    parser = argparse.ArgumentParser(
        description="Finetune Qwen embedding models for code retrieval.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog="""
Examples:
    python qwen-finetune.py --model 0.6B --epochs 1 ;
    python qwen-finetune.py --model 4B --batch-size 16 --learning-rate 2e-5 ;
    python qwen-finetune.py --validation-split 0.1 --eval-steps 200 ;
    python qwen-finetune.py --push-to-hub --hub-model-id username/qwen-0.6b-code
        """,
    )

    parser.add_argument(
        "--model",
        choices=("0.6B", "4B", "8B"),
        default="0.6B",
        help="Base Qwen embedding model variant to finetune.",
    )
    parser.add_argument(
        "--train-file",
        default="data/train_queries.csv",
        help="CSV file containing 'query' and 'code' columns.",
    )
    parser.add_argument(
        "--output-dir",
        default="models/qwen-finetuned",
        help="Directory where the fine-tuned model checkpoints will be stored.",
    )
    parser.add_argument(
        "--run-name",
        default=None,
        help="Optional run name. Defaults to a timestamped folder inside output-dir.",
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of training epochs."
    )
    parser.add_argument(
        "--batch-size", type=int, default=8, help="Batch size for training dataloader."
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=1,
        help="Number of gradient accumulation steps to simulate larger batch sizes.",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=5e-5,
        help="Learning rate passed to the AdamW optimizer.",
    )
    parser.add_argument(
        "--warmup-ratio",
        type=float,
        default=0.1,
        help="Portion of total steps used for LR warmup.",
    )
    parser.add_argument(
        "--validation-split",
        type=float,
        default=0.1,
        help="Fraction of data reserved for validation (0 disables validation).",
    )
    parser.add_argument(
        "--eval-steps",
        type=int,
        default=0,
        help=(
            "How frequently to run validation (in training steps). Set to 0 to only evaluate at the end."
        ),
    )
    parser.add_argument(
        "--checkpoint-steps",
        type=int,
        default=0,
        help="How frequently to save intermediate checkpoints. 0 disables checkpointing.",
    )
    parser.add_argument(
        "--save-total-limit",
        type=int,
        default=None,
        help="Maximum number of checkpoints to keep on disk (None disables cleanup).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=2025,
        help="Random seed for shuffling and weight initialisation.",
    )
    parser.add_argument(
        "--mixed-precision",
        action="store_true",
        help="Enable mixed precision (fp16/amp) training if supported.",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the fine-tuned model to the Hugging Face Hub after training.",
    )
    parser.add_argument(
        "--hub-model-id",
        default=None,
        help="Target repository name on the Hugging Face Hub (e.g. username/model-name).",
    )
    parser.add_argument(
        "--hub-token",
        default=None,
        help="Optional Hugging Face access token used for hub operations.",
    )
    parser.add_argument(
        "--hub-private-repo",
        action="store_true",
        help="Create the Hub repository as private when pushing for the first time.",
    )
    parser.add_argument(
        "--logging-steps",
        type=int,
        default=None,
        help="How frequently to log training metrics (None keeps the library default).",
    )
    parser.add_argument(
        "--gradient-checkpointing",
        action="store_true",
        help="Enable gradient checkpointing to trade extra compute for lower activation memory.",
    )
    parser.add_argument(
        "--optim",
        choices=(
            "adamw_torch",
            "adamw_torch_fused",
            "adamw_apex_fused",
            "adamw_hf",
            "adamw_anyprecision",
            "adamw_bnb_8bit",
            "paged_adamw_8bit",
            "paged_adamw_32bit",
        ),
        default="adamw_torch",
        help="Optimizer to use during training. 8-bit and paged variants require bitsandbytes.",
    )

    args = parser.parse_args()
    run_name = args.run_name or datetime.now().strftime("%Y%m%d-%H%M%S")

    return TrainingConfig(
        model_variant=args.model,
        train_file=args.train_file,
        output_dir=args.output_dir,
        run_name=run_name,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        validation_split=args.validation_split,
        eval_steps=args.eval_steps,
        checkpoint_steps=args.checkpoint_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        mixed_precision=args.mixed_precision,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_token=args.hub_token,
        hub_private_repo=args.hub_private_repo,
        logging_steps=args.logging_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        optim=args.optim,
    )


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_qwen_model(model_variant: str) -> SentenceTransformer:
    model_map = {
        "0.6B": "Qwen/Qwen3-Embedding-0.6B",
        "4B": "Qwen/Qwen3-Embedding-4B",
        "8B": "Qwen/Qwen3-Embedding-8B",
    }

    if model_variant not in model_map:
        raise ValueError(
            f"Unknown Qwen model variant '{model_variant}'. Choose from {list(model_map)}."
        )

    model_name = model_map[model_variant]
    console.print(f"[bold green]Loading base model:[/bold green] {model_name}")
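    # With flash-attn installed, load the weights in bfloat16 with FlashAttention-2 and
    # let device_map="auto" handle placement; otherwise fall back to the default loader.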
    return (
        SentenceTransformer(
            model_name,
            model_kwargs={
                "attn_implementation": "flash_attention_2",
                "dtype": torch.bfloat16,
                "device_map": "auto",
            },
            tokenizer_kwargs={"padding_side": "left"},
        )
        if flash_attn_installed
        else SentenceTransformer(model_name)
    )


def split_train_validation(
    df: pd.DataFrame, validation_split: float, seed: int
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
    """Shuffle with the given seed and split off a validation fraction; return (train_df, val_df)."""
    if validation_split <= 0:
        return df, None

    val_size = int(len(df) * validation_split)
    if val_size == 0:
        console.print(
            "[yellow]Validation split too small for the given dataset; proceeding without validation.[/yellow]"
        )
        return df, None

    df_shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    val_df = df_shuffled.iloc[:val_size].reset_index(drop=True)
    train_df = df_shuffled.iloc[val_size:].reset_index(drop=True)
    return train_df, val_df


def build_ir_evaluator(df) -> InformationRetrievalEvaluator:
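    """Build an evaluator that treats each row's code snippet as the only relevant document for its query."""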
    queries: Dict[str, str] = {}
    corpus: Dict[str, str] = {}
    relevant_docs: Dict[str, Dict[str, int]] = {}

    for idx, row in enumerate(df.itertuples(index=False)):
        query_id = f"q{idx}"
        code_id = f"c{idx}"
        queries[query_id] = str(getattr(row, "query", ""))
        corpus[code_id] = str(getattr(row, "code", ""))
        relevant_docs[query_id] = {code_id: 1}

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        mrr_at_k=[1, 5, 10],
        ndcg_at_k=[1, 5, 10],
        accuracy_at_k=[1, 5, 10],
        show_progress_bar=True,
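        # batch_size=1 keeps evaluation memory minimal; raise it if you have headroom.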
        batch_size=1,
    )


def log_dataset_summary(train_df, val_df) -> None:
    table = Table(title="Dataset Summary")
    table.add_column("Split", justify="left", style="cyan", no_wrap=True)
    table.add_column("Samples", justify="right", style="magenta")

    table.add_row("Train", f"{len(train_df)}")
    if val_df is not None:
        table.add_row("Validation", f"{len(val_df)}")
    else:
        table.add_row("Validation", "0 (disabled)")

    console.print(table)


def save_config(config: TrainingConfig, output_dir: Path) -> None:
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir / "training_config.json"
    with config_path.open("w", encoding="utf-8") as fp:
        json.dump(asdict(config), fp, indent=2)
    console.print(
        f"[bold green]Saved training configuration to:[/bold green] {config_path}"
    )


def main() -> None:
    config = parse_arguments()

    console.print("[bold magenta]Starting Qwen embedding finetuning[/bold magenta]")
    console.print(f"• Model variant: [bold]{config.model_variant}[/bold]")
    console.print(f"• Training file: [bold]{config.train_file}[/bold]")

    set_seed(config.seed)

    train_df = load_train_data(config.train_file)
    train_df = train_df.dropna(subset=["query", "code"]).reset_index(drop=True)

    if train_df.empty:
        raise RuntimeError("Training dataset is empty after dropping NA rows.")

    train_df, val_df = split_train_validation(
        train_df, config.validation_split, config.seed
    )
    log_dataset_summary(train_df, val_df)

    model = load_qwen_model(config.model_variant)
    if config.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            console.print(
                "[bold cyan]Gradient checkpointing enabled for the encoder.[/bold cyan]"
            )
            model.gradient_checkpointing_enable()
        else:
            console.print(
                "[yellow]Gradient checkpointing requested but unsupported by the current model.[/yellow]"
            )
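    # CachedMultipleNegativesRankingLoss (GradCache) computes the in-batch-negatives loss
    # in small chunks; mini_batch_size=1 minimises activation memory at the cost of speed,
    # which is a large part of what lets an 8B model train on a single 48 GB GPU.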
    train_loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)

    base_output_dir = Path(config.output_dir)
    run_output_dir = base_output_dir / config.run_name
    run_output_dir.mkdir(parents=True, exist_ok=True)
    save_config(config, run_output_dir)

    evaluator = None
    if val_df is not None and not val_df.empty:
        console.print(
            "[bold cyan]Validation enabled:[/bold cyan] InformationRetrievalEvaluator"
        )
        evaluator = build_ir_evaluator(val_df)

    # The trainer passes dataset columns to the loss in order, so ("query", "code")
    # become the (anchor, positive) pairs expected by the ranking loss.
    train_dataset = Dataset.from_pandas(
        train_df[["query", "code"]], preserve_index=False
    )
    eval_dataset = None
    if val_df is not None and not val_df.empty:
        eval_dataset = Dataset.from_pandas(
            val_df[["query", "code"]], preserve_index=False
        )

    # Map CLI options onto Trainer strategies: step-based when an interval is given, else per-epoch or off.
    eval_strategy = "no"
    training_args_kwargs = {}
    if evaluator is not None:
        if config.eval_steps > 0:
            eval_strategy = "steps"
            training_args_kwargs["eval_steps"] = config.eval_steps
        else:
            eval_strategy = "epoch"

    if config.checkpoint_steps > 0:
        save_strategy = "steps"
        training_args_kwargs["save_steps"] = config.checkpoint_steps
    else:
        save_strategy = "epoch" if evaluator is not None else "no"

    if config.logging_steps is not None:
        training_args_kwargs["logging_steps"] = config.logging_steps

    if config.save_total_limit is not None:
        training_args_kwargs["save_total_limit"] = config.save_total_limit

    bitsandbytes_optimizers = {
        "adamw_bnb_8bit",
        "paged_adamw_8bit",
        "paged_adamw_32bit",
    }
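    # The 8-bit variants quantise optimizer state and the paged variants allocate it in
    # paged memory that can spill to CPU, both of which help the optimizer fit alongside
    # a large model.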
    if config.optim in bitsandbytes_optimizers:
        try:
            import bitsandbytes as _bnb  # noqa: F401
        except ImportError as exc:
            raise RuntimeError(
                f"Optimizer '{config.optim}' requires the bitsandbytes package. Install it with `pip install bitsandbytes`."
            ) from exc

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(run_output_dir),
        run_name=config.run_name,
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        eval_strategy=eval_strategy,
        # Only evaluate at step 0 when there is something to evaluate with.
        eval_on_start=evaluator is not None,
        save_strategy=save_strategy,
        seed=config.seed,
        # fp16 and bf16 cannot both be set; prefer bf16 when the flash-attn/bf16 path is active.
        fp16=config.mixed_precision and not flash_attn_installed,
        bf16=flash_attn_installed,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
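        # NO_DUPLICATES keeps repeated texts out of a batch so they cannot act as
        # false negatives for the in-batch-negatives loss.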
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        push_to_hub=config.push_to_hub,
        hub_model_id=config.hub_model_id,
        hub_token=config.hub_token,
        hub_private_repo=config.hub_private_repo,
        report_to=["tensorboard"],
        load_best_model_at_end=evaluator is not None
        and save_strategy != "no"
        and eval_strategy != "no",
        **training_args_kwargs,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=train_loss,
        evaluator=evaluator,
    )

    console.print("[bold green]Commencing training...[/bold green]")
    trainer.train()

    console.print("[bold green]Saving final model checkpoint...[/bold green]")
    trainer.save_model()

    if evaluator is not None:
        console.print(
            "[bold cyan]Running final evaluation on validation set...[/bold cyan]"
        )
        scores = evaluator(trainer.model, output_path=str(run_output_dir))
        for metric, score in scores.items():
            console.print(f"  • {metric}: [magenta]{score:.4f}[/magenta]")

    if config.push_to_hub:
        console.print("[bold blue]Pushing model to the Hugging Face Hub...[/bold blue]")
        trainer.push_to_hub()

    console.print(
        f"[bold green]Training complete.[/bold green] Fine-tuned model saved to [bold]{run_output_dir}[/bold]"
    )


if __name__ == "__main__":
    main()

Oh, that's great! Thank you very much.
