|
33 | 33 |
|
34 | 34 |
|
35 | 35 | class GFPOTrainer(_GRPOTrainer):
|
36 |
| - """ |
37 |
| - GFPO proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise |
38 |
| - Reasoning](https://www.arxiv.org/abs/2508.09726) is aimed to train a LLM that demonstrates efficient COT (Chain of |
39 |
| - Thought) without significant performance degradation. To activate GFPO in GRPOTrainer: |
40 |
| - - set `num_remains_in_group` in [`GRPOConfig`] |
41 |
| - - define a group filter function and set it to `group_filter_func` in [`GRPOTrainer`]. `group_filter_func` will |
42 |
| - score the `num_generations` completions and filter the group to get top `num_remains_in_group` completions as a |
43 |
| - new group. Model will be trained on the filtered group. |
44 |
| -
|
45 |
| - Example: |
46 |
| -
|
47 |
| - ```python |
48 |
| - # train_grpo.py |
49 |
| - from trl.experimental.gfpo import GFPOConfig, GFPOTrainer |
50 |
| -
|
51 |
| - # dummy group filter to scores the completions based on its indice in group |
52 |
| - class GroupFilter: |
53 |
| - def __call__(self, group_completions, group_rewards, **kwargs): |
54 |
| - group_scores = [] |
55 |
| - for completions, rewards in zip(group_completions, group_rewards): |
56 |
| - scores = [float(i) for i in range(len(completions))] |
57 |
| - group_scores.append(scores) |
58 |
| - return group_scores |
59 |
| -
|
60 |
| - training_args = GFPOConfig( |
61 |
| - output_dir="Qwen3-0.6B-GFPO" |
62 |
| - per_device_train_batch_size=4, |
63 |
| - num_remains_in_group=2, |
64 |
| - bf16=True, |
65 |
| - ) |
66 |
| - trainer = GFPOTrainer( |
67 |
| - model="Qwen/Qwen3-0.6B", |
68 |
| - reward_funcs=..., |
69 |
| - train_dataset=..., |
70 |
| - args=training_args, |
71 |
| - group_filter_func=GroupFilter(), |
72 |
| - ) |
73 |
| - trainer.train() |
74 |
| - ``` |
75 |
| -
|
76 |
| - Args: |
77 |
| - group_filter_func (`GroupFilterFunc`, *optional*, defaults to `None`): |
78 |
| - Group filter function to filter the group before GRPO, group_filter_func should be not None when |
79 |
| - `num_remains_in_group` is given. |
80 |
| - """ |
81 |
| - |
82 | 36 | def __init__(
|
83 | 37 | self,
|
84 | 38 | model,
|
|
0 commit comments