38 changes: 38 additions & 0 deletions docs/source/experimental.md
@@ -43,6 +43,44 @@ trainer = DPOTrainer(
trainer.train()
```

### GFPO

This feature implements the GFPO algorithm, which filters sampled completion groups to encourage concise reasoning in the model's generations, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726).

To activate GFPO in [`GFPOTrainer`]:

- set `num_remains_in_group` in [`GFPOConfig`]
- define a group filter function and pass it as `group_filter_func` to [`GFPOTrainer`]. `group_filter_func` scores the `num_generations` completions in each group, and the trainer keeps the top `num_remains_in_group` completions by score as the new group. The model is then trained on this filtered group.

```python
# train_gfpo.py
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer

# dummy group filter that scores each completion by its index in the group
class GroupFilter:
    def __call__(self, group_completions, group_rewards, **kwargs):
        group_scores = []
        for completions, rewards in zip(group_completions, group_rewards):
            scores = [float(i) for i in range(len(completions))]
            group_scores.append(scores)
        return group_scores

training_args = GFPOConfig(
    output_dir="Qwen3-0.6B-GFPO",
    per_device_train_batch_size=4,
    num_remains_in_group=2,
    bf16=True,
)
trainer = GFPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=...,
    train_dataset=...,
    args=training_args,
    group_filter_func=GroupFilter(),
)
trainer.train()
```
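
The index-based filter above is only a placeholder. For a filter closer to the paper's goal of concise reasoning, you can score completions by negative length so that only the shortest responses in each group survive filtering. This is a minimal sketch assuming plain-string (standard-format) completions; `LengthGroupFilter` is a hypothetical name, not part of TRL:

```python
# Hypothetical length-based group filter, in the spirit of the paper:
# shorter completions get higher scores, so the trainer keeps the
# `num_remains_in_group` most concise responses in each group.
class LengthGroupFilter:
    def __call__(self, group_completions, group_rewards, **kwargs):
        group_scores = []
        for completions, rewards in zip(group_completions, group_rewards):
            # Negative character length: the shortest completion scores highest.
            scores = [-float(len(completion)) for completion in completions]
            group_scores.append(scores)
        return group_scores
```

Character length is used here for simplicity; token counts (as used in the paper) work the same way once a tokenizer is available.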

## Usage

6 changes: 6 additions & 0 deletions docs/source/paper_index.md
@@ -202,6 +202,12 @@ training_args = GRPOConfig(
)
```

### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning

**📜 Paper**: https://huggingface.co/papers/2508.09726

See [Experimental - GFPO](experimental#gfpo).

## Direct Policy Optimization

Papers relating to the [`DPOTrainer`]
16 changes: 16 additions & 0 deletions trl/experimental/gfpo/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .gfpo_config import GFPOConfig
from .gfpo_trainer import GFPOTrainer
36 changes: 36 additions & 0 deletions trl/experimental/gfpo/gfpo_config.py
@@ -0,0 +1,36 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from ...trainer.grpo_config import GRPOConfig as _GRPOConfig


@dataclass
class GFPOConfig(_GRPOConfig):
    num_remains_in_group: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of completions to keep in each group after applying the group filter function. If set, "
            "`num_remains_in_group` must be >= 2 and less than `num_generations`."
        },
    )

    def __post_init__(self):
        super().__post_init__()

        if self.num_remains_in_group is not None and self.num_remains_in_group >= self.num_generations:
            raise ValueError(
                f"num_remains_in_group ({self.num_remains_in_group}) must be less than num_generations "
                f"({self.num_generations})."
            )
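
# A hypothetical instantiation satisfying the check above, for illustration
# only: sample 8 completions per prompt, keep the 4 best-scored after
# filtering (2 <= num_remains_in_group < num_generations).
#
#     config = GFPOConfig(
#         output_dir="out",
#         num_generations=8,       # completions sampled per prompt (from GRPOConfig)
#         num_remains_in_group=4,  # must be >= 2 and < num_generations
#     )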