
FSDP sharded PEFT classifier saving broken #3732

@ojh31

Description


System Info

- `Accelerate` version: 1.7.0
- Platform: Linux-5.15.0-151-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /opt/venv/bin/accelerate
- Python version: 3.11.13
- Numpy version: 2.2.6
- PyTorch version: 2.7.1+cu126
- PyTorch accelerator: CUDA
- System RAM: 1003.13 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

accelerate_config.yaml

compute_environment: LOCAL_MACHINE
debug: false
# We want FSDP to shard model parameters between devices.
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: "bf16"
num_machines: 1
# We overwrite this with a CLI argument
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

test_save_fsdp_model.py

# run with accelerate launch --config_file accelerate_config.yaml test_save_fsdp_model.py 

import torch
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from accelerate.utils.fsdp_utils import save_fsdp_model
from accelerate import Accelerator



def main():
    model_name = "meta-llama/Llama-3.1-8B-Instruct"
    output_dir = "/workspace/outputs/test_reward"

    # Load the tokenizer and model for sequence classification
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=1,  # For reward modeling
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_cache=False,
        device_map=None
    )
    peft_config = LoraConfig(
        r=64,
        lora_alpha=64*2,
        bias="none",
        task_type="SEQ_CLS",
        target_modules="all-linear",
    )
    model = get_peft_model(model, peft_config)
    accelerator = Accelerator()
    if getattr(accelerator.state, "fsdp_plugin", None):
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) # type: ignore

    # Add padding token if it doesn't exist
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        
    model = accelerator.prepare(model)
    output_dir = "./test-reward-checkpoint-on-train-begin"
    save_fsdp_model(
        accelerator.state.fsdp_plugin, accelerator, model, output_dir, adapter_only=True
    ) # Errors with message below
    
    #   File "/opt/venv/lib/python3.11/site-packages/peft/utils/other.py", line 569, in <dictcomp>
    #     k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
    #   File "/opt/venv/lib/python3.11/site-packages/peft/utils/other.py", line 568, in adapter_state_dict
    #     k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
    #     for k in self.modules_to_save[adapter_name].state_dict()
    #   File "/opt/venv/lib/python3.11/site-packages/peft/utils/save_and_load.py", line 208, in get_peft_model_state_dict
    #     {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
    #   File "/opt/venv/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 58, in _get_model_state_dict
    #     return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
    #   File "/opt/venv/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 126, in save_fsdp_model
    #     state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)
    #   File "/workspace/scripts/test_reward.py", line 73, in main
    #   File "/workspace/scripts/test_reward.py", line 81, in <module>
    # KeyError: 'modules_to_save.default.weight'

    print("Script completed.")


if __name__ == "__main__":
    main() 

Expected behavior

The adapter should be saved to `adapter_model.safetensors`. Instead the call fails with `KeyError: 'modules_to_save.default.weight'`. The problem is that `get_peft_model_state_dict` compares names from `model.named_modules()` against keys from `model.state_dict()` using `startswith`, and this fails because the two sides disagree on whether the name starts with `_fsdp_wrapped_module`.
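
For illustration, the mismatch can be inspected on the model returned by accelerator.prepare; this is a hypothetical sketch, and the example names in the comments (including which side carries the prefix) are assumptions rather than captured output:

# Hypothetical inspection of the prepared (FSDP-wrapped) model. The example names in
# the comments are illustrative assumptions, not captured output.
module_names = [n for n, _ in model.named_modules() if "modules_to_save" in n]
# e.g. '_fsdp_wrapped_module.base_model.model.score.modules_to_save.default'
state_keys = [k for k in model.state_dict() if "modules_to_save" in k]
# e.g. 'base_model.model.score.modules_to_save.default.weight'
# One side carries the `_fsdp_wrapped_module.` prefix and the other does not, so the
# startswith comparison in get_peft_model_state_dict never matches and the lookup of
# 'modules_to_save.default.weight' raises the KeyError.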

`get_peft_model_state_dict` should be called on the unwrapped model, as it is when using `save_pretrained`.
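
Until that is fixed, a rough, untested workaround sketch is to bypass `save_fsdp_model` for the adapter and save via `PeftModel.save_pretrained` on the unwrapped model, passing a state dict gathered under FULL_STATE_DICT so module names and state-dict keys agree. Note this materializes the full (CPU-offloaded) state dict on rank 0, and the helper name `save_adapter_workaround` is made up for this example; variable names follow the reproduction script above.

# Rough, untested workaround sketch: save the adapter with PeftModel.save_pretrained on
# the unwrapped model instead of save_fsdp_model. save_pretrained forwards `state_dict`
# to get_peft_model_state_dict, which is then applied to the unwrapped model.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig


def save_adapter_workaround(accelerator, model, output_dir):
    accelerator.wait_for_everyone()
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    # Temporarily gather an unsharded state dict; its keys carry no FSDP wrapper prefixes.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    if accelerator.is_main_process:
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(output_dir, state_dict=state_dict)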
