TRL documentation

GFPO

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v1.7.1).
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

GFPO

This feature implements the GFPO algorithm to enforce concise reasoning in the model’s output generation, as proposed in the paper Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning.

Usage

To activate GFPO in GFPOTrainer:

  • set num_remains_in_group in GFPOConfig
  • define a group filter function and pass it as group_filter_func to GFPOTrainer. group_filter_func scores the num_generations completions of each group, and GFPOTrainer then filters each group by those scores, keeping the top num_remains_in_group completions as the new group. The model is trained on the filtered group.
# train_gfpo.py
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer

# dummy group filter that scores each completion by its index within the group
class GroupFilter:
    """Dummy group filter that scores each completion by its position.

    Within every group, the completion at index ``i`` gets the score
    ``float(i)``. The rewards are not used for scoring; they are only
    paired with the groups, so any groups beyond the shorter of the two
    inputs are dropped.
    """

    def __call__(self, group_completions, group_rewards, **kwargs):
        """Return one list of float scores per (completions, rewards) pair.

        Args:
            group_completions: Groups of completions; each group must
                support ``len()``.
            group_rewards: Per-group rewards, paired with the groups but
                otherwise ignored by this dummy filter.
            **kwargs: Accepted and ignored.

        Returns:
            A list with one inner list per group, holding the scores
            ``0.0, 1.0, ..., len(group) - 1`` as floats.
        """
        return [
            list(map(float, range(len(members))))
            for members, _unused_rewards in zip(group_completions, group_rewards)
        ]

# Training configuration: ordinary trainer arguments plus the GFPO-specific
# `num_remains_in_group`.
training_args = GFPOConfig(
    output_dir="Qwen3-0.6B-GFPO",
    per_device_train_batch_size=4,
    # Number of top-scoring completions kept in each group after filtering.
    num_remains_in_group=2,
    bf16=True,
)
# The `...` values are placeholders: supply your own reward function(s) and
# training dataset before running this script.
trainer = GFPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=...,
    train_dataset=...,
    args=training_args,
    # Callable that scores the completions of each group; the trainer uses
    # these scores to decide which completions remain in the group.
    group_filter_func=GroupFilter(),
)
trainer.train()

GFPOTrainer

class trl.experimental.gfpo.GFPOTrainer

< >

( model, reward_funcs, args = None, train_dataset = None, eval_dataset = None, processing_class = None, reward_processing_classes = None, group_filter_func = None, callbacks = None, optimizers = (None, None), peft_config = None )

train

< >

( resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None ) → ~trainer_utils.TrainOutput

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

Returns

~trainer_utils.TrainOutput

Object containing the global step count, training loss, and metrics.

Main training entry point.

save_model

< >

( output_dir: str | None = None, _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub

< >

( commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

GFPOConfig

class trl.experimental.gfpo.GFPOConfig

< >

( output_dir: str | None = None, per_device_train_batch_size: int = 8, num_train_epochs: float = 3.0, max_steps: int = -1, learning_rate: float = 1e-06, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_steps: float = 0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, optim_target_modules: None | str | list[str] = None, gradient_accumulation_steps: int = 1, average_tokens_across_devices: bool = True, max_grad_norm: float = 1.0, label_smoothing_factor: float = 0.0, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, use_cache: bool = False, neftune_noise_alpha: float | None = None, torch_empty_cache_steps: int | None = None, auto_find_batch_size: bool = False, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_steps: float = 10, logging_first_step: bool = False, log_on_each_node: bool = True, logging_nan_inf_filter: bool = True, include_num_input_tokens_seen: str | bool = 'no', log_level: str = 'passive', log_level_replica: str = 'warning', disable_tqdm: bool | None = None, report_to: None | str | list[str] = 'none', run_name: str | None = None, project: str = 'huggingface', trackio_space_id: str | None = None, trackio_bucket_id: str | None = None, trackio_static_space_id: typing.Union[str, NoneType, typing.Literal[False]] = None, eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', eval_steps: float | None = None, eval_delay: float = 0, per_device_eval_batch_size: int = 8,
prediction_loss_only: bool = False, eval_on_start: bool = False, eval_do_concat_batches: bool = True, eval_use_gather_object: bool = False, eval_accumulation_steps: int | None = None, include_for_metrics: list = <factory>, batch_eval_metrics: bool = False, save_only_model: bool = False, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_on_each_node: bool = False, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, push_to_hub: bool = False, hub_token: str | None = None, hub_private_repo: bool | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_always_push: bool = False, hub_revision: str | None = None, load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, restore_callback_states_from_checkpoint: bool = False, full_determinism: bool = False, seed: int = 42, data_seed: int | None = None, use_cpu: bool = False, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, dataloader_drop_last: bool = False, dataloader_num_workers: int = 0, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, dataloader_prefetch_factor: int | None = None, remove_unused_columns: bool | None = False, label_names: list[str] | None = None, train_sampling_strategy: str = 'random', length_column_name: str = 'length', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, ddp_static_graph: bool | None = None, ddp_backend: str | None = None, ddp_timeout: int = 1800, fsdp: str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, deepspeed: dict | str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', skip_memory_metrics: bool = True, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, resume_from_checkpoint: str | None = None,
warmup_ratio: float | None = None, logging_dir: str | None = None, local_rank: int = -1, model_init_kwargs: dict[str, typing.Any] | str | None = None, trust_remote_code: bool = False, router_aux_loss_coef: float = 0.001, disable_dropout: bool = False, cast_lm_head_to_fp32: bool = False, num_generations: int | None = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool | None = True, pad_to_multiple_of: int | None = None, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, repetition_penalty: float = 1.0, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'colocate', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_structured_outputs_regex: str | None = None, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_server_timeout: float = 240.0, vllm_group_port: int = 51216, vllm_gpu_memory_utilization: float = 0.3, vllm_max_model_length: int | None = None, vllm_tensor_parallel_size: int = 1, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.2, delta: float | None = None, epsilon_high: float | None = None, sapo_temperature_neg: float = 1.05, sapo_temperature_pos: float = 1.0, vespo_k_pos: float = 2.0, vespo_lambda_pos: float = 3.0, vespo_k_neg: float = 3.0, vespo_lambda_neg: float = 2.0, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, multi_objective_aggregation: str = 'sum_then_normalize', scale_rewards: str = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, entropy_coef: float = 0.0, use_adaptive_entropy: bool = False, entropy_coef_min: float = 0.0, entropy_coef_max: float = 1.0,
entropy_coef_delta: float = 0.005, entropy_target: float = 0.2, max_tool_calling_iterations: int | None = None, vllm_importance_sampling_correction: bool = True, vllm_importance_sampling_mode: str = 'sequence_mask', vllm_importance_sampling_clip_max: float | None = 3.0, vllm_importance_sampling_clip_min: float | None = None, off_policy_mask_threshold: float | None = None, use_bias_correction_kl: bool = False, log_completions: bool = False, num_completions_to_print: int | None = None, log_unique_prompts: bool = False, log_completions_hub_repo: str | None = None, use_transformers_continuous_batching: bool = False, transformers_continuous_batching_config: dict | None = None, use_transformers_paged: bool = False, vllm_importance_sampling_cap: float | None = None, num_remains_in_group: int | None = None )

Update on GitHub