Fix: Make sft script work when chat template is None #3995
Conversation
@@ -123,7 +126,7 @@ def main(script_args, training_args, model_args, dataset_args):
     # Set default chat template if needed
     if tokenizer.chat_template is None:
         # TODO: source should be passed as an argument
-        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
+        model, tokenizer, _added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
Suggested change:
-        model, tokenizer, _added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
+        model, tokenizer, _ = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
I wanted to do this as well, but then I made it more explicit to the users.
As you like... ruff etc. might complain. I think this and the Optional should be the only changes in this PR.
agree with @kashif
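The suggestion above comes down to a naming convention for unused return values. A minimal sketch of the two options, using a hypothetical stand-in for `clone_chat_template` (the stub and its return values are illustrative, not the real TRL implementation):

```python
# Hypothetical stand-in for clone_chat_template, which returns three values:
# the (possibly resized) model, the updated tokenizer, and any added tokens.
def clone_template_stub(model, tokenizer, source):
    added_tokens = ["<extra_0>", "<extra_1>"]  # illustrative only
    return model, tokenizer, added_tokens

# Option 1 (as written in the PR): name the third value explicitly, so readers
# see that something is being discarded and what it is.
model, tokenizer, _added_tokens = clone_template_stub("m", "t", "Qwen/Qwen3-0.6B")

# Option 2 (the reviewer's suggestion): discard it with the bare underscore,
# which linters such as ruff treat as an intentionally unused binding.
model, tokenizer, _ = clone_template_stub("m", "t", "Qwen/Qwen3-0.6B")
```

Both unpack correctly; the trade-off is explicitness versus brevity, and ruff's unused-variable checks are typically configured to ignore names matching `_` or a leading-underscore pattern.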
from accelerate import logging
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
I think you can revert this change
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser
IMO, using a consistent indent level for the return is a better practice
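The refactor being discussed can be sketched with plain `argparse` stand-ins (the real script uses `TrlParser` and passes `dataclass_types`, which are omitted here; `argparse._SubParsersAction` is the hint the actual code path implies):

```python
import argparse
from typing import Optional

def make_parser(subparsers: Optional[argparse._SubParsersAction] = None) -> argparse.ArgumentParser:
    # Build either a subcommand parser or a standalone one.
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script")
    else:
        parser = argparse.ArgumentParser()
    # Single return at one indent level, rather than returning from each branch.
    return parser
```

Ending both branches at one shared `return` makes it obvious at a glance that the function always returns a parser, whichever branch was taken.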
What does this PR do?
This pull request makes improvements to the trl/scripts/sft.py script, primarily focusing on import optimization, correcting lazy importing, and minor API adjustments for clarity and correctness. The most significant changes are grouped below:

Import optimization and lazy loading:
- Removed the top-level import of AutoModelForCausalLM and moved it inside the relevant code path in the main function so the model class is loaded lazily.

API and function signature improvements:
- Added _added_tokens as the third unpacking target because clone_chat_template returns three values, not two.
- Updated the make_parser function so the type hint for the subparsers parameter is accurate, and simplified the logic for returning the correct parser instance.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
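The lazy-loading change in the description follows a common Python pattern: move a heavy import from module scope into the function that needs it, so importing the script stays cheap. A minimal sketch, using the stdlib module colorsys as a stand-in for the heavy transformers import (the function and flag names are hypothetical):

```python
def run_step(needs_model: bool):
    # Heavy dependency imported lazily, only on the code path that uses it.
    # In the PR, AutoModelForCausalLM is imported inside main() the same way;
    # colorsys merely stands in for "from transformers import AutoModelForCausalLM".
    if needs_model:
        import colorsys
        return colorsys.rgb_to_hsv(1.0, 0.0, 0.0)
    # When the branch is not taken, the import never executes.
    return None
```

The payoff is that `python sft.py --help` (or importing the module for its parser) no longer pays the import cost of the model class unless training actually runs.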
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.