
Conversation

rabinadk1

What does this PR do?

This pull request improves the trl/scripts/sft.py script, focusing on import optimization (making an eager import lazy) and minor API adjustments for clarity and correctness. The most significant changes are grouped below:

Import optimization and lazy loading:

  • Removed the eager import of AutoModelForCausalLM at the top of the file and moved it inside the relevant code path in the main function so it is loaded lazily (see the sketch below).
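
A minimal sketch of the lazy-import pattern described above (the function body here is illustrative, not the script's exact code):

```python
def main(script_args, training_args, model_args, dataset_args):
    # Imported here, on the code path that needs it,
    # rather than eagerly at the top of the module.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
    ...
```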

API and function signature improvements:

  • Added _added_tokens as the third unpacking target, since clone_chat_template returns three values, not two.
  • Refactored the make_parser function so the type hint for the subparsers argument is accurate, and simplified the logic for returning the correct parser instance (sketched after this list).
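
For the subparsers hint, a sketch of what the corrected signature could look like (assuming the fix wraps the existing argparse._SubParsersAction annotation in Optional to match the None default, which the later review comment about "the Optional" suggests):

```python
import argparse
from typing import Optional

from trl import TrlParser


def make_parser(subparsers: Optional[argparse._SubParsersAction] = None) -> TrlParser:
    ...
```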

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

```diff
@@ -123,7 +126,7 @@ def main(script_args, training_args, model_args, dataset_args):
     # Set default chat template if needed
     if tokenizer.chat_template is None:
         # TODO: source should be passed as an argument
-        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
+        model, tokenizer, _added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```
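
Context for the fix above: unpacking a three-element return into two names fails at runtime, so the previous two-name form would raise a ValueError. A minimal illustration (toy function, not the TRL API):

```python
def returns_three():
    return "model", "tokenizer", ["<new_token>"]

model, tokenizer = returns_three()
# ValueError: too many values to unpack (expected 2)
```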
Collaborator

Suggested change:

```diff
-model, tokenizer, _added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
+model, tokenizer, _ = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```

Author

I wanted to do this as well, but then opted to make it more explicit to users.

Collaborator

As you like... ruff etc. might complain. I think this and the Optional should be the only changes in this PR.

Member

Agree with @kashif.


```python
from accelerate import logging
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
```

Member

I think you can revert this change.

Comment on lines -164 to -168:

```python
if subparsers is not None:
    parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
else:
    parser = TrlParser(dataclass_types)
return parser
```

Member

IMO, using a consistent indent level for the return is a better practice.
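
To make the style point concrete, here is a sketch of the two shapes being compared (imports assume trl's top-level exports; the early-return version is my reading of what the PR introduced, not the PR's exact code):

```python
from trl import DatasetMixtureConfig, ModelConfig, ScriptArguments, SFTConfig, TrlParser

dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)


def make_parser_single_return(subparsers=None):
    # One return at a consistent indent level, as the reviewer prefers.
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


def make_parser_early_return(subparsers=None):
    # Early returns inside each branch.
    if subparsers is not None:
        return subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    return TrlParser(dataclass_types)
```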
