
Conversation

Peter-Chou
Contributor

@Peter-Chou Peter-Chou commented Sep 1, 2025

What does this PR do?

This PR implements GFPO in GRPOTrainer, as proposed in the paper Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning.
GFPO aims to train an LLM that produces efficient chain-of-thought (CoT) reasoning without significant performance degradation.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@Peter-Chou Peter-Chou changed the title add GFPO feature in GRPOTrainer [GFPO]: implement GFPO in GRPOTrainer Sep 2, 2025
@qgallouedec
Member

Thanks for the PR!

Just went through the paper. I feel like the idea can be implemented without any modification to the codebase, leveraging the fact that when None is returned by the reward function, the corresponding sample is ignored. Dummy example where one out of every two completions is ignored:

def math_reward_func(prompts, completions, **kwargs):
    rewards = []
    for idx, (prompt, completion) in enumerate(zip(prompts, completions)):
        if idx % 2 == 0:
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            rewards.append(None)
    return rewards

can you confirm my intuition?

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 3, 2025

I believe that filtering out unwanted completions by masking their rewards does not align with the core solution proposed by GFPO. The reason the paper gives for filtering completions rather than changing rewards is:

In section two

A key limitation of GRPO is its reliance on a single scalar reward signal, making it difficult to jointly optimize multiple desirable response attributes, such as brevity and accuracy. This often leads to gains in accuracy at the cost of substantial response length inflation. To address this, we introduce GFPO to enable simultaneous optimization of multiple response properties.

In section three

While it may seem natural to directly encode desirable attributes such as brevity or informativeness into the scalar reward, doing so for multiple traits can be challenging, especially when correctness must already be captured.

The Idea of GFPO is data filtration

Data filtration instead serves as an implicit, flexible form of reward shaping—akin to iterative self-improvement methods that use selective sampling to amplify specific model behaviors

GFPO's metrics validate reward integrity

While GFPO is general-purpose and can accommodate various scoring metrics, our experiments specifically leverage metrics aimed at reducing response length inflation:
• Response Length: Training on short responses directly encourages brevity.
• Token Efficiency (reward/length): Training on highly token-efficient responses encourages succinctness, but still allows longer responses if sufficiently “justified” by proportionately higher rewards.
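
For concreteness, the two scoring metrics above could be computed along these lines (my own sketch, not code from the paper):

# Illustrative filtering metrics: keep the completions with the best scores in each group.
def response_length(completion_ids):
    # Keeping the shortest responses directly encourages brevity.
    return len(completion_ids)

def token_efficiency(reward, completion_ids):
    # Reward per token: longer responses survive the filter only if their reward
    # is proportionately higher.
    return reward / max(len(completion_ids), 1)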

@qgallouedec
Member

Sorry, but I still don't get the difference. How is it different from doing this, for example:

from collections import defaultdict


def reward_func(prompts, completions_ids, **kwargs):
    num_remains_in_group = 2
    rewards = [1.0] * len(prompts)  # default reward

    # Group indices by prompt
    groups = defaultdict(list)
    for idx, prompt in enumerate(prompts):
        groups[prompt].append(idx)

    # For each group, keep the `num_remains_in_group` shortest completions and deactivate the rest
    for prompt, indices in groups.items():
        # Sort indices in this group by completion length (ascending)
        sorted_indices = sorted(indices, key=lambda i: len(completions_ids[i]))

        # Deactivate everything beyond the first `num_remains_in_group`
        for i in sorted_indices[num_remains_in_group:]:
            rewards[i] = None

    return rewards


prompts = ["P1", "P1", "P1", "P2", "P2", "P2"]
completions_ids = [
    [11, 12, 13],
    [14, 15, 16, 17], # longest in group, reward=None
    [18, 19],
    [21, 22, 23, 24, 25], # longest in group, reward=None
    [26, 27],
    [28, 29, 30]
]

print(reward_func(prompts, completions_ids))
# [1.0, None, 1.0, None, 1.0, 1.0]

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 3, 2025

For example, your rewards come from several reward functions, some rule-based and others model-based. You aggregate the rewards to get the overall reward for each completion, then filter the completions based on response length and token efficiency (reward/length). Is this feasible with the approach in your example above?

@qgallouedec
Member

I think so, you'd just replace

rewards = [1.0] * len(prompts)

by something like

rewards = [reward_func1(p, c) + reward_func2(p, c) for p, c in zip(prompts, completions)]

The only limitation is that the trainer won't be able to log these rewards separately, since the aggregation happens inside a single reward function.

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 3, 2025

I have some reservations about your approach.
When consolidating multiple reward functions (including model-based ones) into a single reward function, the model (e.g. a PreTrainedModel) would have to be managed inside the reward function, along with its preprocessing components (such as an AutoTokenizer) for reward calculation.
This implementation does not fully leverage the current procedure:

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
    zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
):

Even when rewards are set to None, the subsequent aggregation converts them to zero:

rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

Relative advantages will still be computed based on both satisfied and filtered completions, which will significantly affect the mean and standard deviation of rewards.
For example, if you mask rewards and keep the top 2 out of 8 completions, the rewards in the group would look like this: [0., 0., 2., 0., 0., 0., 6., 0.].
(but if you filter and get top 2 completions, rewards in the new group should look like this: [2., 6.])
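
To make the effect concrete, here is a quick numerical sketch (my own illustration, assuming the usual group mean/std normalization):

import torch

# Hypothetical group of 8 completions where only the top 2 should be kept.
masked = torch.tensor([0., 0., 2., 0., 0., 0., 6., 0.])  # reward masking: None aggregated to 0
filtered = torch.tensor([2., 6.])                         # GFPO: filtered completions are dropped

def advantages(rewards):
    return (rewards - rewards.mean()) / (rewards.std() + 1e-4)

print(advantages(masked))    # all 8 completions receive an advantage and therefore a gradient
print(advantages(filtered))  # only the 2 retained completions are trained on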

Ultimately, completions that should have been filtered out in the first place still undergo both forward and backward propagation. In effect, you're training the model on all completions (retained completions with the right reward, filtered completions with zero reward), and the relative advantages are wrong.

# Compute the per_token_logps and the entropy at each position in the completion
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
    model,
    input_ids,
    attention_mask,
    logits_to_keep,
    compute_entropy=True,
    pixel_values=inputs.get("pixel_values"),
    image_grid_thw=inputs.get("image_grid_thw"),
    pixel_attention_mask=inputs.get("pixel_attention_mask"),
    image_sizes=inputs.get("image_sizes"),
)

I believe the core concept of GFPO is to introduce data filtration as a critical new dimension, complementing reward mechanisms for better completion preference alignment.

@Peter-Chou
Contributor Author

What do you think? @qgallouedec

@Peter-Chou
Contributor Author

@LeonEricsson My thoughts don't quite align with @qgallouedec's. Could you please take a look together with us?

@LeonEricsson
Collaborator

LeonEricsson commented Sep 11, 2025

Relative advantages will still be computed based on both satisfied and filtered completions, which will significantly affect the mean and standard deviation of rewards.

I agree with this. None values are currently treated as 0, which skews mean/std calculations:

rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

However, I believe that we can still implement GFPO as a reward function with minimal changes to the trainer.

Note: In practice, None is used when datasets have disjoint tasks with separate reward functions (e.g. a coding reward returns None on math samples). We shouldn’t break this use case.

Assuming:

from collections import defaultdict

def reward_length(prompts, completions, **kwargs):
    num_remains_in_group = 2
    # reward_func1 / reward_func2 stand in for the user's own reward functions
    rewards = [reward_func1(p, c) + reward_func2(p, c) for p, c in zip(prompts, completions)]

    groups = defaultdict(list)
    for idx, prompt in enumerate(prompts):
        groups[prompt].append(idx)

    for _, indices in groups.items():
        sorted_indices = sorted(indices, key=lambda i: len(completions[i]))
        for i in sorted_indices[num_remains_in_group:]:
            rewards[i] = None
    return rewards

Here, filtered samples have None rewards across all reward functions. This currently triggers a warning:

if torch.isnan(rewards_per_func).all(dim=1).any():
    nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
    row_reward_kwargs = {
        key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
    }
    row_reward_kwargs["prompt"] = prompts[nan_row_idx]
    row_reward_kwargs["completion"] = completions[nan_row_idx]
    logger.warning(
        f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
        "Please ensure that at least one reward function returns a valid reward."
    )

# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)

but I propose we treat this as a GFPO-specific case and filter these during advantage calculation. A boolean config flag could distinguish intentional GFPO filtering from accidental Nones (warning case).
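
Roughly what I have in mind on the advantage side (just a sketch; treat_none_as_filtered is a made-up flag name, not an existing option):

import torch

def group_advantages(rewards, num_generations, treat_none_as_filtered=True):
    # Compute GRPO-style advantages per group; when the flag is set, NaN (filtered)
    # rewards are excluded from the group mean/std instead of being treated as 0.
    rewards = rewards.view(-1, num_generations)  # (num_prompts, G)
    keep = ~torch.isnan(rewards) if treat_none_as_filtered else torch.ones_like(rewards, dtype=torch.bool)
    filled = torch.where(keep, rewards, torch.zeros_like(rewards))
    count = keep.sum(dim=1, keepdim=True).clamp(min=1)
    mean = filled.sum(dim=1, keepdim=True) / count
    var = ((filled - mean) ** 2 * keep).sum(dim=1, keepdim=True) / count
    adv = (rewards - mean) / (var.sqrt() + 1e-4)
    # Filtered completions get zero advantage and thus no policy-gradient signal.
    return torch.where(keep, adv, torch.zeros_like(adv))

# Example: a group of 4 where the 2 longest completions were filtered (None -> NaN upstream)
r = torch.tensor([1.0, float("nan"), 3.0, float("nan")])
print(group_advantages(r, num_generations=4))  # roughly [-1., 0., 1., 0.]

This way the group statistics come only from the retained completions, which addresses the mean/std skew above.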

@Peter-Chou what do you think?

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 11, 2025

@LeonEricsson Thanks for your reply, but I still think implementing the filter by masking rewards through a reward function is not the ideal approach.

For example, if the reward is derived from aggregating the outputs of multiple reward functions (some rule-based and some model-based), this approach would not align with the current procedure, as I described earlier:

When consolidating multiple reward functions (including model-based ones) into a single reward function, the model (e.g. a PreTrainedModel) would have to be managed inside the reward function, along with its preprocessing components (such as an AutoTokenizer) for reward calculation. This implementation does not fully leverage the current procedure:

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
    zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
):

You would need to maintain the reward model and the relevant tokenizer within the reward function on each node (and using a reward model under DeepSpeed ZeRO-3 introduces extra complexity).
This means that when a user wants to activate GFPO in the current GRPO training process, they must not only decide how to score the completions for filtering purposes, but also figure out how to integrate the existing reward strategies with the filter into a single reward function, and take responsibility for initializing all of that outside the GRPOTrainer.

Additionally, the filter function serves as another way to incorporate preferences into completions; it can consider factors beyond completion length, as noted in the paper. Separating the filter function from the reward function is preferable.

I don't think turning the reward function into a Swiss Army knife is a good idea. The purpose of a reward function is both specific and straightforward.

@LeonEricsson
Collaborator

LeonEricsson commented Sep 12, 2025

You would need to maintain the reward model and the relevant tokenizer within the reward function on each node (and using a reward model under DeepSpeed ZeRO-3 introduces extra complexity). This means that when a user wants to activate GFPO in the current GRPO training process, they must not only decide how to score the completions for filtering purposes, but also figure out how to integrate the existing reward strategies with the filter into a single reward function, and take responsibility for initializing all of that outside the GRPOTrainer.

I see your point, this is a fair argument. Decoupling filtering from the reward functions seems like the sustainable solution.

@qgallouedec
Member

Hey, we now have an experimental submodule. I think GFPO could be added to this submodule for now. See https://huggingface.co/docs/trl/main/en/experimental. Example: #3898
Please share any difficulty porting this implementation into experimental. We can modify GRPO to make it more hackable if necessary.

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 13, 2025

@qgallouedec Thank you for your advice. But I wonder whether it is really necessary to move the GFPO implementation into the trl.experimental submodule?
A callback doesn't seem like the right approach here, and hacking the GRPOTrainer so that the GFPO implementation works as a plug-in component presents a tougher challenge than the current solution.

If GFPO must be placed in trl.experimental, maybe we could create a copy of GRPOTrainer with the GFPO implementation as trl.experimental.trainer.GFPOTrainer, along with a GRPOConfig as trl.experimental.trainer.GFPOConfig?

@qgallouedec
Member

A callback doesn't seem like the right approach here

Sorry, it was probably not clear: the suggestion was not about having a callback, but rather about having the code in experimental.

I'd recommend something like:

# trl/experimental/gfpo/grpo_config.py
from dataclasses import dataclass, field
from typing import Optional

from ...trainer.grpo_config import GRPOConfig as _GRPOConfig

@dataclass
class GRPOConfig(_GRPOConfig):
    num_remains_in_group: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of completions kept in each group after the group filter function; must be >= 2 if given."
        },
    )

# trl/experimental/gfpo/grpo_trainer.py
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer

class GRPOTrainer(_GRPOTrainer):
    def __init__(self, model, reward_funcs, group_filter_func, args, train_dataset, eval_dataset, processing_class, reward_processing_classes, callbacks, optimizers, peft_config): 
        super().__init__(model, reward_funcs, args, train_dataset, eval_dataset, processing_class, reward_processing_classes, callbacks, optimizers, peft_config)
        self.group_filter_func = group_filter_func

    def _generate_and_score_completions(self, inputs):
        ...

(or trl/experimental/gfpo/gfpo_trainer.py and class GFPOTrainer(GRPOTrainer):, as you want)
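
And the group filter itself could be something as small as this (signature purely illustrative, not a fixed API, e.g. keeping the num_remains_in_group shortest completions of each group):

# Hypothetical group filter: given one group's completions and their rewards,
# return the indices of the completions to keep.
def keep_shortest(completions, rewards, num_remains_in_group=2):
    order = sorted(range(len(completions)), key=lambda i: len(completions[i]))
    return order[:num_remains_in_group]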

Hacking the GRPOTrainer to enable the GFPO implementation to function as a plug-in component also presents a tougher challenge than the current solution.

I completely agree. The proposed solution requires copying a lot of GRPO code—but, as surprising as it sounds, that’s actually intentional: the experimental submodule is meant to stress-test how customizable TRL really is.
In the case of GFPO, it’s clearly not easy, since it involves copy-pasting large sections of GRPO.
The upside is that this is very useful for the dev team, because it shows exactly where we need to do a better job.
Moving forward, we should make TRL easier to extend so that ideas like GFPO can be implemented in just a few lines.

@Peter-Chou Peter-Chou force-pushed the gfpo branch 2 times, most recently from 961f57b to 7199f2b on September 14, 2025 at 00:44
@Peter-Chou
Contributor Author

Peter-Chou commented Sep 14, 2025

@qgallouedec I integrated the GFPO implementation into trl.experimental by inheriting from GRPOTrainer and overriding the _generate_and_score_completions method.
I understand you intend for the trl.experimental submodule to serve as a testing ground for the development team to evaluate and refine new features before their official integration. While I fully support this approach, I must highlight a potential fragility: copying extensive code from GRPOTrainer could become problematic, given its frequent updates. A significant change might break the dependent implementation. There may be more robust alternatives in the long term, but for now, let’s proceed with your suggested method to ensure progress.

@qgallouedec
Member

💯 agree, that's why we need to refactor GRPO to avoid this.

@qgallouedec
Member

qgallouedec commented Sep 14, 2025

Can you add a short subsection in the docs (section Experimental) as well?

@Peter-Chou
Contributor Author

Peter-Chou commented Sep 14, 2025

@qgallouedec Yes, the GFPO documentation has been added to experimental.md.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title [GFPO]: implement GFPO in GRPOTrainer 🌪️ [GFPO]: implement GFPO in GRPOTrainer Sep 18, 2025
@qgallouedec qgallouedec merged commit 10dc36d into huggingface:main Sep 18, 2025
9 of 10 checks passed