TRL documentation

MiniLLM Trainer

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v0.25.1).


Overview

TRL supports the MiniLLM Trainer for distilling large language models into smaller ones using the reverse Kullback-Leibler divergence (KLD), which the authors report yields more precise, higher-quality responses, as described in the paper Knowledge Distillation of Large Language Models by Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang. The abstract from the paper is the following:

Knowledge Distillation (KD) is a promising technique for reducing the high computational demand of large language models (LLMs). However, previous KD methods are primarily applied to white-box classification models or training small models to imitate black-box model APIs like ChatGPT. How to effectively distill the knowledge from white-box generative LLMs is still under-explored, which becomes more and more important with the prosperity of LLMs. In this work, we propose MiniLLM that distills smaller language models from generative larger language models. We first replace the forward Kullback-Leibler divergence (KLD) objective in the standard KD approaches with reverse KLD, which is more suitable for KD on generative language models, to prevent the student model from overestimating the low-probability regions of the teacher distribution. Then, we derive an effective optimization approach to learn this objective. Extensive experiments in the instruction-following setting show that the MiniLLM models generate more precise responses with the higher overall quality, lower exposure bias, better calibration, and higher long-text generation performance. Our method is also scalable for different model families with 120M to 13B parameters. We will release our code and model checkpoints at https://aka.ms/MiniLLM.
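To make the abstract's point about low-probability regions concrete, the following sketch compares forward and reverse KLD for two candidate students against a toy bimodal teacher. The distributions and function names are made up for illustration; this is plain Python, not part of the TRL API.

```python
import math


def forward_kl(teacher, student):
    """KL(teacher || student): the objective used by standard KD."""
    return sum(p * math.log(p / q) for p, q in zip(teacher, student) if p > 0)


def reverse_kl(teacher, student):
    """KL(student || teacher): the objective MiniLLM uses."""
    return sum(q * math.log(q / p) for p, q in zip(teacher, student) if q > 0)


# Toy teacher: two high-probability outcomes and two low-probability ones.
teacher = [0.45, 0.05, 0.05, 0.45]
# This student spreads mass evenly, overestimating the teacher's low-probability outcomes.
spread_student = [0.25, 0.25, 0.25, 0.25]
# This student concentrates on one of the teacher's modes.
mode_student = [0.85, 0.05, 0.05, 0.05]

for name, student in [("spread", spread_student), ("mode", mode_student)]:
    fwd = forward_kl(teacher, student)
    rev = reverse_kl(teacher, student)
    print(f"{name}: forward={fwd:.3f} reverse={rev:.3f}")
```

Forward KLD prefers the spread student (about 0.368 vs 0.703), while reverse KLD prefers the mode-seeking one (about 0.431 vs 0.511), which is the behavior the paper relies on.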

This post-training method was contributed by Yuxian Gu.

It is a generalized version of Thinking Machines Lab’s On-Policy Distillation, with the option to add distribution-level single-step distillation signals (like GKD when beta=1) and long-context reverse KLD signals. Here $\alpha_1$ and $\alpha_2$ weight the two terms and $\gamma$ discounts the log-ratios of later tokens:

$$
\begin{aligned}
L_{\text{MiniLLM}} &= \alpha_1\,\mathbb{E}_{x\sim \pi_{\theta}}\sum_{t'=t}^{|x|}\frac{\gamma^{t'-t}}{\sum_{t'}\gamma^{t'-t}}\left[\log \frac{\pi_{\theta}(x_{t'+1}\mid x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}\mid x_{1..t'})}\right] \\
&\quad + \alpha_2\,\mathbb{E}_{x\sim \pi_{\theta}}\,\text{KL}\left[\pi_\theta(\cdot\mid x_{1..t})\,\|\,\pi_{\text{teacher}}(\cdot\mid x_{1..t})\right].
\end{aligned}
$$
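A minimal sketch of the $\alpha_1$ term at a single position $t$, assuming you already have the log-probabilities that the student and the teacher assign to each sampled token of one completion. The function name and the 0-based indexing are illustrative assumptions, not TRL internals.

```python
def discounted_log_ratio(student_logps, teacher_logps, t, gamma):
    """Normalized, gamma-discounted average of future student/teacher log-ratios from position t.

    student_logps[k] and teacher_logps[k] are the log-probabilities both models assign
    to the sampled token at (0-based) position k of the same completion.
    """
    log_ratios = [ls - lt for ls, lt in zip(student_logps, teacher_logps)]
    # Weights gamma^(t'-t) for t' = t, ..., last position; Python evaluates 0.0 ** 0 as 1.0
    weights = [gamma ** (k - t) for k in range(t, len(log_ratios))]
    weighted = sum(w * r for w, r in zip(weights, log_ratios[t:]))
    # Divide by the sum of weights, matching the normalizer in the alpha_1 term
    return weighted / sum(weights)


student_logps = [-1.0, -2.0, -0.5]
teacher_logps = [-1.5, -1.0, -0.5]
print(discounted_log_ratio(student_logps, teacher_logps, t=0, gamma=0.0))  # only position 0 contributes
print(discounted_log_ratio(student_logps, teacher_logps, t=0, gamma=1.0))  # plain average over positions 0..2
```

With gamma=0 all weight collapses onto position t, so the term reduces to a single per-token log-ratio; larger gamma lets later tokens of the completion contribute.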

Setting $\alpha_1=1$, $\alpha_2=0$, and $\gamma=0$ corresponds to the following configuration:

from trl.experimental.minillm import MiniLLMConfig

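training_args = MiniLLMConfig(
    rkl_advantage=True,               # alpha_1 = 1: keep the sequence-level reverse KLD term
    single_step_decomposition=False,  # alpha_2 = 0: drop the single-step distribution-level term
    gamma=0.0,                        # gamma = 0: only the current token's log-ratio is used
)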

With these settings, $L_{\text{MiniLLM}}$ becomes the on-policy KD implemented in Tinker:

$$
L_{\text{tinker}}=\mathbb{E}_{x\sim \pi_{\theta}}\left[\log \frac{\pi_{\theta}(x_{t+1}\mid x_{1..t})}{\pi_{\text{teacher}}(x_{t+1}\mid x_{1..t})}\right].
$$
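The per-token log-ratio in $L_{\text{tinker}}$ is a single-sample estimate of the reverse KLD: averaging it over tokens sampled from the student recovers the exact divergence. The sketch below checks this on a toy three-token vocabulary; the distributions and function names are made up for illustration.

```python
import math
import random


def exact_reverse_kl(student, teacher):
    """KL(student || teacher) for two categorical distributions given as probability lists."""
    return sum(q * math.log(q / p) for q, p in zip(student, teacher) if q > 0)


def sampled_reverse_kl(student, teacher, num_samples, seed=0):
    """Average of log student(x) - log teacher(x) over tokens x drawn from the student."""
    rng = random.Random(seed)
    tokens = rng.choices(range(len(student)), weights=student, k=num_samples)
    return sum(math.log(student[i]) - math.log(teacher[i]) for i in tokens) / num_samples


student = [0.7, 0.2, 0.1]  # toy student next-token distribution
teacher = [0.4, 0.4, 0.2]  # toy teacher next-token distribution
print(exact_reverse_kl(student, teacher))
print(sampled_reverse_kl(student, teacher, num_samples=100_000))
```

Because tokens are drawn from the student, the estimate only needs the teacher's log-probability of the sampled token, not its full distribution.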

Setting $\alpha_1=0$ and $\alpha_2=1$ corresponds to the following configuration:

from trl.experimental.minillm import MiniLLMConfig

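training_args = MiniLLMConfig(
    rkl_advantage=False,             # alpha_1 = 0: drop the sequence-level reverse KLD term
    single_step_decomposition=True,  # alpha_2 = 1: keep the single-step distribution-level term
)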

With these settings, $L_{\text{MiniLLM}}$ becomes the reverse KLD version of the GKD loss, as in the GKD Trainer:

$$
L_{\text{GKD-RKL}}=\mathbb{E}_{x\sim \pi_{\theta}}\,\text{KL}\left[\pi_\theta(\cdot\mid x_{1..t})\,\|\,\pi_{\text{teacher}}(\cdot\mid x_{1..t})\right].
$$
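Unlike the sampled log-ratio, this single-step term sums over the whole vocabulary at each prefix, so it needs the full next-token logits from both models. A minimal sketch for one position, assuming plain Python lists of logits (a real implementation would operate on batched tensors):

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def single_step_reverse_kl(student_logits, teacher_logits):
    """KL(student || teacher) between the two full next-token distributions at one prefix."""
    student_logp = log_softmax(student_logits)
    teacher_logp = log_softmax(teacher_logits)
    return sum(math.exp(ls) * (ls - lt) for ls, lt in zip(student_logp, teacher_logp))


student_logits = [2.0, 1.0, 0.0]
teacher_logits = [0.0, 0.0, 0.0]
print(single_step_reverse_kl(student_logits, teacher_logits))  # reverse direction
print(single_step_reverse_kl(teacher_logits, student_logits))  # forward direction, for comparison
```

The two printed values differ, which is why the direction of the divergence matters for distillation.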

MiniLLMTrainer

class trl.experimental.minillm.MiniLLMTrainer


(
    model: str | transformers.modeling_utils.PreTrainedModel,
    teacher_model: transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str,
    reward_funcs: str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]] | list[str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]]] | None = None,
    args: trl.experimental.minillm.minillm_config.MiniLLMConfig | None = None,
    train_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None,
    eval_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | dict[str, datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset] | None = None,
    processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.processing_utils.ProcessorMixin | None = None,
    reward_processing_classes: transformers.tokenization_utils_base.PreTrainedTokenizerBase | list[transformers.tokenization_utils_base.PreTrainedTokenizerBase] | None = None,
    callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None,
    optimizers: tuple = (None, None),
    peft_config: PeftConfig | None = None,
    rollout_func: collections.abc.Callable[[list[str], typing.Any, typing.Any], dict[str, typing.Any]] | None = None
)

Parameters

  • model (str | PreTrainedModel) — Model to be trained. Can be either:

    • A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
    • A PreTrainedModel object. Only causal language models are supported.
  • teacher_model (PreTrainedModel | nn.Module | str) — Teacher model used for knowledge distillation. Instantiated similarly to model.
  • reward_funcs (RewardFunc | list[RewardFunc], optional) — Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:

    • A single reward function, such as:

      • A string: The model ID of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with num_labels=1 and the keyword arguments in args.model_init_kwargs.

      • A PreTrainedModel object: Only sequence classification models are supported.

      • A custom reward function: The function is provided with the prompts and the generated completions, plus any additional columns in the dataset. It should return a list of rewards. Custom reward functions can also return None when the reward is not applicable to those samples. This is useful for multi-task training where different reward functions apply to different types of samples. When a reward function returns None for a sample, that reward function is excluded from the reward calculation for that sample. For more details, see Using a custom reward function.

        The trainer’s state, an instance of TrainerState, is also passed to the reward function and can be accessed through the trainer_state argument in the reward function’s signature.

    • A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed.

  • args (experimental.minillm.MiniLLMConfig, optional) — Configuration for this trainer. If None, a default configuration is used.
  • train_dataset (Dataset or IterableDataset) — Dataset to use for training. It must include a column "prompt". Any additional columns in the dataset are ignored. The format of the samples can be either:

    • Standard: Each sample contains plain text.
    • Conversational: Each sample contains structured messages (e.g., role and content).
  • eval_dataset (Dataset, IterableDataset or dict[str, Dataset | IterableDataset]) — Dataset to use for evaluation. It must meet the same requirements as train_dataset.
  • processing_class (PreTrainedTokenizerBase, ProcessorMixin, optional) — Processing class used to process the data. The padding side must be set to “left”. If None, the processing class is loaded from the model’s name with from_pretrained. A padding token, tokenizer.pad_token, must be set. If the processing class has not set a padding token, tokenizer.eos_token will be used as the default.
  • reward_processing_classes (PreTrainedTokenizerBase or list[PreTrainedTokenizerBase], optional) — Processing classes corresponding to the reward functions specified in reward_funcs. Can be either:

    • A single processing class: Used when reward_funcs contains only one reward function.
    • A list of processing classes: Must match the order and length of the reward functions in reward_funcs. If set to None, or if an element of the list corresponding to a PreTrainedModel is None, the tokenizer for the model is automatically loaded using from_pretrained. For elements in reward_funcs that are custom reward functions (not PreTrainedModel), the corresponding entries in reward_processing_classes are ignored.
  • callbacks (list of TrainerCallback, optional) — List of callbacks to customize the training loop. These are added to the list of default callbacks described in the Transformers callbacks documentation.

    If you want to remove one of the default callbacks used, use the remove_callback method.

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional, defaults to (None, None)) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
  • peft_config (PeftConfig, optional) — PEFT configuration used to wrap the model. If None, the model is not wrapped.
  • rollout_func (RolloutFunc, optional) — Function to use for generating completions. It must take prompts, args, and processing_class as parameters and return a dict with "prompt_ids", "completion_ids", and "logprobs" fields. Any other fields are forwarded to the reward functions. This feature is experimental and may change or be removed at any time without prior notice.
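A minimal sketch of the rollout_func contract described above, assuming the returned values are per-prompt lists of token IDs and per-token log-probabilities (the exact nesting TRL expects is an assumption here). The function, the extra "prompt_length" field, and the character-level tokenizer are hypothetical stand-ins; a real rollout would call an actual generation backend.

```python
from typing import Any


def toy_rollout_func(prompts: list[str], args: Any, processing_class: Any) -> dict[str, Any]:
    """Hypothetical rollout_func that 'generates' a single EOS token per prompt.

    A real implementation would call a generator here and return the sampled token IDs
    together with the sampler's per-token log-probabilities.
    """
    prompt_ids, completion_ids, logprobs, prompt_lengths = [], [], [], []
    for prompt in prompts:
        ids = processing_class.encode(prompt)
        completion = [processing_class.eos_token_id]  # placeholder completion
        prompt_ids.append(ids)
        completion_ids.append(completion)
        logprobs.append([0.0] * len(completion))  # log(1.0): the placeholder is deterministic
        prompt_lengths.append(len(ids))
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        # Extra, hypothetical field: per the parameter description, it is forwarded to reward functions
        "prompt_length": prompt_lengths,
    }


class CharTokenizer:
    """Stand-in for a real tokenizer: one ID per character, EOS = 0."""

    eos_token_id = 0

    def encode(self, text: str) -> list[int]:
        return [ord(ch) for ch in text]


rollout = toy_rollout_func(["hi", "abc"], args=None, processing_class=CharTokenizer())
print(rollout["prompt_ids"], rollout["completion_ids"])
```

Hugging Face tokenizers also expose encode and eos_token_id, so the same sketch accepts a real tokenizer as processing_class.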

Trainer for the Knowledge Distillation of Language Models (MiniLLM) method. This algorithm was initially proposed in the paper Knowledge Distillation of Large Language Models.

Example:

from datasets import load_dataset
from trl.experimental.minillm import MiniLLMTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = MiniLLMTrainer(
    model="Qwen/Qwen3-0.6B",
    teacher_model="Qwen/Qwen3-1.7B",
    train_dataset=dataset,
)
trainer.train()
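The example above trains with the distillation signal alone. reward_funcs also accepts custom callables; the sketch below is a hypothetical reward that favors short completions and returns None for samples it does not apply to. The "task" column is an assumed extra dataset column (trl-lib/tldr has none, in which case the reward applies to every sample), and the scaling is arbitrary.

```python
def length_penalty_reward(prompts, completions, task=None, **kwargs):
    """Hypothetical reward: shorter completions score higher; None where it does not apply.

    Extra dataset columns (here the assumed `task` column) arrive as keyword arguments;
    kwargs also carries `trainer_state`.
    """
    rewards = []
    for i, completion in enumerate(completions):
        if task is not None and task[i] != "summarize":
            rewards.append(None)  # reward not applicable: excluded for this sample
            continue
        # Conversational datasets yield a list of message dicts; standard ones yield plain strings
        text = completion if isinstance(completion, str) else completion[0]["content"]
        rewards.append(-len(text.split()) / 100.0)
    return rewards


# trainer = MiniLLMTrainer(
#     model="Qwen/Qwen3-0.6B",
#     teacher_model="Qwen/Qwen3-1.7B",
#     reward_funcs=length_penalty_reward,
#     train_dataset=dataset,
# )
```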

train


( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None, trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None, ignore_keys_for_eval: typing.Optional[list[str]] = None, **kwargs: typing.Any )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments

Main training entry point.

save_model


( output_dir: typing.Optional[str] = None, _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub


( commit_message: typing.Optional[str] = 'End of training', blocking: bool = True, token: typing.Optional[str] = None, revision: typing.Optional[str] = None, **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

MiniLLMConfig

class trl.experimental.minillm.MiniLLMConfig


( output_dir: typing.Optional[str] = None, overwrite_output_dir: bool = False, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, per_gpu_train_batch_size: typing.Optional[int] = None, per_gpu_eval_batch_size: typing.Optional[int] = None, gradient_accumulation_steps: int = 1, eval_accumulation_steps: typing.Optional[int] = None, eval_delay: float = 0, torch_empty_cache_steps: typing.Optional[int] = None, learning_rate: float = 1e-06, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_ratio: float = 0.0, warmup_steps: int = 0, log_level: str = 'passive', log_level_replica: str = 'warning', log_on_each_node: bool = True, logging_dir: typing.Optional[str] = None, logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', logging_first_step: bool = False, logging_steps: float = 10, logging_nan_inf_filter: bool = True, save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps', save_steps: float = 500, save_total_limit: typing.Optional[int] = None, save_safetensors: bool = True, save_on_each_node: bool = False, save_only_model: bool = False, restore_callback_states_from_checkpoint: bool = False, no_cuda: bool = False, use_cpu: bool = False, use_mps_device: bool = False, seed: int = 42, data_seed: typing.Optional[int] = None, jit_mode_eval: bool = False, bf16: bool | None = None, fp16: bool = False, fp16_opt_level: str = 'O1', half_precision_backend: str = 'auto', bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: typing.Optional[bool] = None, local_rank: int = -1,
ddp_backend: typing.Optional[str] = None, tpu_num_cores: typing.Optional[int] = None, tpu_metrics_debug: bool = False, debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '', dataloader_drop_last: bool = False, eval_steps: typing.Optional[float] = None, dataloader_num_workers: int = 0, dataloader_prefetch_factor: typing.Optional[int] = None, past_index: int = -1, run_name: typing.Optional[str] = None, disable_tqdm: typing.Optional[bool] = None, remove_unused_columns: bool | None = False, label_names: typing.Optional[list[str]] = None, load_best_model_at_end: bool = False, metric_for_best_model: typing.Optional[str] = None, greater_is_better: typing.Optional[bool] = None, ignore_data_skip: bool = False, fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = None, fsdp_min_num_params: int = 0, fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None, fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None, accelerator_config: typing.Union[dict, str, NoneType] = None, parallelism_config: typing.Optional[accelerate.parallelism_config.ParallelismConfig] = None, deepspeed: typing.Union[dict, str, NoneType] = None, label_smoothing_factor: float = 0.0, optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch_fused', optim_args: typing.Optional[str] = None, adafactor: bool = False, group_by_length: bool = False, length_column_name: str = 'length', report_to: typing.Union[NoneType, str, list[str]] = None, project: str = 'huggingface', trackio_space_id: typing.Optional[str] = 'trackio', ddp_find_unused_parameters: typing.Optional[bool] = None, ddp_bucket_cap_mb: typing.Optional[int] = None, ddp_broadcast_buffers: typing.Optional[bool] = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, use_legacy_prediction_loop: bool = False, push_to_hub: bool = False, resume_from_checkpoint: typing.Optional[str] = None, hub_model_id: typing.Optional[str] = None,
hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save', hub_token: typing.Optional[str] = None, hub_private_repo: typing.Optional[bool] = None, hub_always_push: bool = False, hub_revision: typing.Optional[str] = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None, include_inputs_for_metrics: bool = False, include_for_metrics: list = <factory>, eval_do_concat_batches: bool = True, fp16_backend: str = 'auto', push_to_hub_model_id: typing.Optional[str] = None, push_to_hub_organization: typing.Optional[str] = None, push_to_hub_token: typing.Optional[str] = None, mp_parameters: str = '', auto_find_batch_size: bool = False, full_determinism: bool = False, torchdynamo: typing.Optional[str] = None, ray_scope: typing.Optional[str] = 'last', ddp_timeout: int = 1800, torch_compile: bool = False, torch_compile_backend: typing.Optional[str] = None, torch_compile_mode: typing.Optional[str] = None, include_tokens_per_second: bool = False, include_num_input_tokens_seen: typing.Union[str, bool] = False, neftune_noise_alpha: typing.Optional[float] = None, optim_target_modules: typing.Union[NoneType, str, list[str]] = None, batch_eval_metrics: bool = False, eval_on_start: bool = False, use_liger_kernel: bool = False, liger_kernel_config: typing.Optional[dict[str, bool]] = None, eval_use_gather_object: bool = False, average_tokens_across_devices: bool = True, model_init_kwargs: dict | str | None = None, disable_dropout: bool = True, cast_lm_head_to_fp32: bool = False, max_prompt_length: int | None = 512, num_generations: int | None = 8, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool | None = True, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int | None = None, min_p: float | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None,
repetition_penalty: float = 1.0, use_transformers_paged: bool = False, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'server', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_guided_decoding_regex: str | None = None, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_server_timeout: float = 240.0, vllm_gpu_memory_utilization: float = 0.3, vllm_tensor_parallel_size: int = 1, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.2, delta: float | None = None, epsilon_high: float | None = None, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, scale_rewards: str = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, use_liger_loss: bool = None, vllm_importance_sampling_correction: bool = True, vllm_importance_sampling_cap: float = 2.0, log_completions: bool = False, num_completions_to_print: int | None = None, log_unique_prompts: bool = False, wandb_log_unique_prompts: bool | None = None, teacher_model_init_kwargs: dict[str, typing.Any] | None = None, rkl_advantage: bool = True, single_step_decomposition: bool = True, kd_temperature: float = 1.0, gamma: float = 0.0, length_normalization: bool = True )

Parameters

  • temperature (float, optional, defaults to 1.0) — Temperature for sampling. The higher the temperature, the more random the completions.
  • teacher_model_init_kwargs (dict[str, Any], optional) — Keyword arguments to pass to AutoModelForCausalLM.from_pretrained when instantiating the teacher model from a string.
  • disable_dropout (bool, optional, defaults to True) — Whether to disable dropout in the model.
  • rkl_advantage (bool, optional, defaults to True) — Whether to include the sequence-level reverse KLD term (the $\alpha_1$ term of $L_{\text{MiniLLM}}$).
  • single_step_decomposition (bool, optional, defaults to True) — Whether to include the single-step, distribution-level reverse KLD term (the $\alpha_2$ term of $L_{\text{MiniLLM}}$).
  • gamma (float, optional, defaults to 0.0) — Discount factor $\gamma$ applied to the log-ratios of later tokens in the $\alpha_1$ term.
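The two special cases from the overview can be summarized as a small lookup from config flags to loss weights. This is a hypothetical helper, not part of TRL, and it assumes each boolean flag simply switches its weight between 0 and 1.

```python
def minillm_loss_weights(rkl_advantage: bool, single_step_decomposition: bool, gamma: float) -> tuple[float, float, float]:
    """Hypothetical mapping from MiniLLMConfig flags to (alpha_1, alpha_2, gamma)."""
    alpha_1 = 1.0 if rkl_advantage else 0.0
    alpha_2 = 1.0 if single_step_decomposition else 0.0
    return alpha_1, alpha_2, float(gamma)


def describe_objective(alpha_1: float, alpha_2: float, gamma: float) -> str:
    """Name the special case of L_MiniLLM selected by the weights."""
    if alpha_1 == 0.0 and alpha_2 == 0.0:
        return "no distillation term"
    if alpha_1 == 1.0 and alpha_2 == 0.0 and gamma == 0.0:
        return "on-policy KD (Tinker)"
    if alpha_1 == 0.0 and alpha_2 == 1.0:
        return "reverse-KLD GKD loss"
    return "general MiniLLM objective"


# Defaults from the signature: rkl_advantage=True, single_step_decomposition=True, gamma=0.0
print(describe_objective(*minillm_loss_weights(True, True, 0.0)))
print(describe_objective(*minillm_loss_weights(True, False, 0.0)))
print(describe_objective(*minillm_loss_weights(False, True, 0.0)))
```

Under this reading, the default configuration combines both terms rather than reducing to either special case.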

Configuration class for MiniLLMTrainer.

This class includes only the parameters that are specific to MiniLLM training. For a full list of training arguments, please refer to the TrainingArguments and GRPOConfig documentation.
