Dft #3960 (Closed)

Commits (69)
41f67a2  add entropy loss (1485840691, Jun 21, 2025)
619b598  add entropy loss to metrics (1485840691, Jun 22, 2025)
bfad506  Merge branch 'main' of https://github.com/1485840691/trl (1485840691, Jun 22, 2025)
07a59c9  re-format (1485840691, Jun 22, 2025)
c4b2eee  use F (1485840691, Jun 22, 2025)
93e049d  Output alignment (1485840691, Jun 23, 2025)
af13140  merge commits (1485840691, Jun 27, 2025)
42d030a  ent coef not equal 0 (1485840691, Jun 27, 2025)
d45f0fa  fix format (1485840691, Jun 27, 2025)
2302f48  Merge branch 'main' into main (1485840691, Jun 27, 2025)
c729673  fix ent loss log (1485840691, Jun 29, 2025)
03ad8da  fix mode (1485840691, Jun 29, 2025)
7b3c95c  Merge branch 'main' of https://github.com/1485840691/trl (1485840691, Jun 29, 2025)
91da19e  Merge branch 'main' into main (1485840691, Jul 1, 2025)
1b30552  update based on review (1485840691, Jul 3, 2025)
f7b4f3c  Merge branch 'main' of https://github.com/1485840691/trl (1485840691, Jul 3, 2025)
e73f820  Merge branch 'main' into main (1485840691, Jul 4, 2025)
b022c79  adaptive entropy control (1485840691, Jul 4, 2025)
97de806  adaptive entropy control update (1485840691, Jul 4, 2025)
a99244c  adaptive entropy control update fmt (1485840691, Jul 4, 2025)
7641827  Merge pull request #1 from 1485840691/tgt_ent (1485840691, Jul 5, 2025)
32d5c7c  refactor loss in grpo (1485840691, Jul 7, 2025)
872e2a4  merge master (1485840691, Jul 7, 2025)
0ef57fc  Merge branch 'main' into main (1485840691, Jul 9, 2025)
70bf9d1  Update comment on dynamic ent control (1485840691, Jul 9, 2025)
6a531ba  Merge branch 'main' of https://github.com/1485840691/trl (1485840691, Jul 9, 2025)
2fbf4da  Merge branch 'main' into main (1485840691, Jul 15, 2025)
44c342b  Merge branch 'main' into main (1485840691, Jul 16, 2025)
5b70fee  Merge branch 'main' into main (1485840691, Jul 24, 2025)
7c05fb6  Merge branch 'main' into main (1485840691, Jul 25, 2025)
72fa76b  update based on feedback (1485840691, Jul 25, 2025)
b3acd9b  Merge branch 'main' into main (1485840691, Jul 25, 2025)
88fa118  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
776cdd2  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
7e97263  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
8c9cd01  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
f1e1da6  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
340f711  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
4c85a76  Update trl/trainer/grpo_config.py (1485840691, Jul 28, 2025)
ade3a3e  Update trl/trainer/grpo_trainer.py (1485840691, Jul 28, 2025)
51f8ca9  Merge branch 'main' into main (1485840691, Jul 28, 2025)
eded222  add test and update based on feedback (1485840691, Jul 28, 2025)
e4ace46  Merge branch 'main' into main (1485840691, Jul 29, 2025)
79c4ee1  update based on feedback (1485840691, Jul 29, 2025)
d77795f  Merge branch 'main' into main (1485840691, Jul 31, 2025)
8c08682  sync update entropy coef (1485840691, Jul 31, 2025)
ab6ede4  Merge branch 'main' of https://github.com/1485840691/trl (1485840691, Jul 31, 2025)
dde326f  Merge branch 'main' into main (1485840691, Aug 1, 2025)
ce8bf67  change coef collective (LeonEricsson, Aug 2, 2025)
22f2fa9  nits (LeonEricsson, Aug 2, 2025)
df32688  separete entropy coefficient from the coefficient that is applied in … (LeonEricsson, Aug 5, 2025)
4d28df5  update tests (LeonEricsson, Aug 5, 2025)
8fe1d94  Merge branch 'main' into main (LeonEricsson, Aug 5, 2025)
adc5bca  Merge branch 'main' into main (1485840691, Aug 5, 2025)
55b2e83  Merge branch 'main' into main (1485840691, Aug 6, 2025)
3d94cd7  Merge branch 'main' into main (1485840691, Aug 7, 2025)
736bb60  Merge branch 'main' into main (1485840691, Aug 12, 2025)
0278e69  Merge branch 'main' into main (LeonEricsson, Aug 12, 2025)
94afea9  Merge branch 'main' into main (1485840691, Aug 13, 2025)
cb1cb85  Merge branch 'huggingface:main' into main (1485840691, Aug 14, 2025)
6f8b6c4  add dynamic ft example (1485840691, Aug 21, 2025)
96c89bd  align with main (1485840691, Aug 21, 2025)
fb0cbc3  dft update code (1485840691, Aug 27, 2025)
aba3d94  align with main (1485840691, Aug 27, 2025)
94c3c78  Merge branch 'main' into dft (1485840691, Aug 27, 2025)
dc1fd82  new line (1485840691, Aug 27, 2025)
949d095  Merge branch 'dft' of https://github.com/1485840691/trl into dft (1485840691, Aug 27, 2025)
d3cf4fc  fix format (1485840691, Aug 27, 2025)
38dfafe  Merge branch 'main' into dft (1485840691, Aug 28, 2025)

trl/scripts/dft.py (208 additions, 0 deletions)
@@ -0,0 +1,208 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
# "trl @ git+https://github.com/huggingface/trl.git",
# "peft",
# ]
# ///

"""
# Full training
```
python trl/scripts/dft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--eval_strategy steps \
--eval_steps 100 \
--output_dir Qwen2-0.5B-DFT \
--push_to_hub
```

# LoRA
```
python trl/scripts/dft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-4 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--eval_strategy steps \
--eval_steps 100 \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--output_dir Qwen2-0.5B-DFT \
--push_to_hub
```
"""

import argparse
import warnings
from typing import Optional

import torch
from datasets import load_dataset
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from trl import (
    DatasetMixtureConfig,
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_dataset,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def compute_loss_fn(outputs, labels, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
    logits = outputs["logits"]
    # Upcast to float before computing the loss to avoid potential precision issues
    logits = logits.float()
    vocab_size = logits.shape[-1]

    # Shift so that tokens < n predict token n
    shift_labels = labels[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    # Per-token cross-entropy, left unreduced so it can be reweighted below
    loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction="none")

    # DFT weighting: scale each token's loss by the model's detached probability of the target token.
    # Labels equal to `ignore_index` are clamped to 0 only to keep `gather` in range; their loss terms
    # are already zeroed by the `ignore_index` handling in `cross_entropy`.
    probs = torch.softmax(shift_logits, dim=-1)
    prob_labels = torch.clamp(shift_labels, min=0)
    prob_coefs = probs.gather(1, prob_labels.unsqueeze(-1)).squeeze(-1).detach()

    loss = loss * prob_coefs

    if num_items_in_batch is not None:
        loss = loss.sum() / num_items_in_batch
    else:
        loss = loss.mean()

    return loss


def main(script_args, training_args, model_args, dataset_args):
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model_kwargs.pop("use_cache", None)  # Image models do not support cache
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )

    # Set default chat template if needed
    if tokenizer.chat_template is None:
        # TODO: source should be passed as an argument
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    # Load the dataset
    if dataset_args.datasets and script_args.dataset_name:
        warnings.warn(
            "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
            "dataset and `dataset_name` will be ignored."
        )
        dataset = get_dataset(dataset_args)
    elif dataset_args.datasets and not script_args.dataset_name:
        dataset = get_dataset(dataset_args)
    elif not dataset_args.datasets and script_args.dataset_name:
        dataset = load_dataset(
            script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
        )
    else:
        raise ValueError("Either `datasets` or `dataset_name` must be provided.")

    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        compute_loss_func=compute_loss_fn,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Save and push to Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    # When using the trl cli, this script may be run with additional arguments, corresponding to accelerate
    # arguments. To ensure that their parsing does not interfere with the script arguments, parse the arguments
    # with `return_remaining_strings=True`, then ignore the remaining strings.
    script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
        return_remaining_strings=True
    )
    main(script_args, training_args, model_args, dataset_args)
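
For reviewers who want to sanity-check the weighting outside the trainer, here is a minimal sketch (not part of the diff) that exercises `compute_loss_fn` on toy tensors. The shapes and values are made up for illustration, and the `from trl.scripts.dft import compute_loss_fn` import path is an assumption based on where this PR adds the file.

```python
# Minimal sanity check of the DFT weighting on toy tensors (illustrative values only).
import torch
from torch import nn

from trl.scripts.dft import compute_loss_fn  # hypothetical import path, based on the file added in this PR

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[0, -2:] = -100  # treat the last two positions of the first sequence as padding

loss = compute_loss_fn({"logits": logits}, labels)

# Manual reference: shift, per-token CE, weight by the detached target-token probability, then mean.
shift_logits = logits[..., :-1, :].reshape(-1, vocab).float()
shift_labels = labels[..., 1:].reshape(-1)
ce = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="none")
probs = torch.softmax(shift_logits, dim=-1)
target_p = probs.gather(1, shift_labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
expected = (ce * target_p.detach()).mean()

assert torch.allclose(loss, expected)
```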