
optimizer.load_state_dict not working #1044

@saurabh-kataria

Description


I am unable to debug this. torchao may be incompatible with Hugging Face Accelerate, or the problem may be something else.

My optimizer is Adam4bit, imported with: from torchao.prototype.low_bit_optim import Adam4bit

line 173, in load_checkpoint
[rank3]: optimizer.load_state_dict(optimizer_state_dict)
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner
[rank3]: return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank3]: return fn(*args, **kwargs)
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 777, in load_state_dict
[rank3]: state[param] = _cast(param, v, param_id=k, param_groups=state_dict['param_groups'])
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 764, in _cast
[rank3]: return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 764, in <dictcomp>
[rank3]: return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 762, in _cast
[rank3]: return Optimizer.process_value_according_to_param_policy(param, value, param_id, param_groups, key)
[rank3]: File "/home/skatar6/.local/lib/python3.9/site-packages/torch/optim/optimizer.py", line 644, in process_value_according_to_param_policy
[rank3]: return value.to(dtype=param.dtype, device=param.device)
[rank3]: File "/home/skatar6/anaconda3/envs/tmp4/lib/python3.9/site-packages/torchao/utils.py", line 377, in dispatch__torch_function
[rank3]: return func(*args, **kwargs)
[rank3]: File "/home/skatar6/anaconda3/envs/tmp4/lib/python3.9/site-packages/torchao/utils.py", line 392, in dispatch__torch_dispatch
[rank3]: kwarg_types = {k: type(arg) for k, arg in kwargs}
[rank3]: File "/home/skatar6/anaconda3/envs/tmp4/lib/python3.9/site-packages/torchao/utils.py", line 392, in <dictcomp>
[rank3]: kwarg_types = {k: type(arg) for k, arg in kwargs}
[rank3]: ValueError: too many values to unpack (expected 2)
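Reading the last frames of the traceback: Optimizer.load_state_dict casts every state value with value.to(dtype=param.dtype, device=param.device), so the torchao tensor subclass receives keyword arguments and reaches the line kwarg_types = {k: type(arg) for k, arg in kwargs}. Iterating a dict yields only its keys, so each key string (such as "dtype") gets unpacked into two names, which would produce exactly this ValueError. If that reading is right, the line probably needs kwargs.items(), and the failure is in torchao itself rather than in Accelerate. The minimal sketch below reproduces the pattern in plain Python; the function names and the string stand-ins for the dtype/device values are hypothetical and are not torchao's API.

```python
# Plain-Python model of the failing comprehension (a sketch, not torchao code).
# The buggy body copies the line from torchao/utils.py shown in the traceback;
# the function names and the stand-in values are hypothetical.


def kwarg_types_buggy(kwargs):
    # Iterating a dict yields its keys only, so each key string
    # (e.g. "dtype") is unpacked into (k, arg) and fails.
    return {k: type(arg) for k, arg in kwargs}


def kwarg_types_fixed(kwargs):
    # .items() yields (key, value) pairs, which is what the
    # comprehension expects.
    return {k: type(arg) for k, arg in kwargs.items()}


# Calls without keyword arguments never hit the bug: an empty dict
# yields nothing to unpack.
print(kwarg_types_buggy({}))  # {}

# load_state_dict casts state tensors with keyword arguments,
# roughly value.to(dtype=..., device=...). Strings stand in for the
# real torch.dtype / torch.device objects here.
cast_kwargs = {"dtype": "float32", "device": "cuda:3"}
try:
    kwarg_types_buggy(cast_kwargs)
except ValueError as err:
    print("buggy:", err)

print("fixed:", kwarg_types_fixed(cast_kwargs))
```

This also suggests why training can run normally and only checkpoint loading fails: the comprehension only breaks when the op is called with non-empty keyword arguments, which the state-dict cast does.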
