
Commit 590f8fb

SpinQuant (#983)
* SpinQuant using R2 matrices
* Move Hadamard functions and matrices to separate file
* Add R4 rotation
* Reformat
* Do not wrap Linear layers but use nn.Sequential. Wrapping the Linear layers might mess with the quantization of the linear layers, so it's probably better to keep the linear layers the same and insert new layers alongside them.
* Add test
* Fix test and do small reformat of Hadamard code
* Fuse LayerNorm params into linear layers. This is done for pre-norm LLMs like LLaMA to make them scale-invariant (see footnote 3 in the paper). However, in the current implementation it seems to hurt performance when quantization is used.
* Add R1 rotation
* Add option to load pretrained R1/R2 matrices
* Move SpinQuant from `torchao/quantization` to `torchao/prototype/spinquant`
* Move Hadamard matrices to a separate file
* Move test
* Minor changes
* Reformat
* Only enable R4 as default setting. Random R1 and R2 matrices are showing worse results than just using R4, so the latter seems to be a better default option (at least for now).
* Add `__init__.py` to spinquant folder
* Do not fail if `fast_hadamard_transform` is not present
1 parent 76b6e36 commit 590f8fb
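For context, the R4 rotation mentioned in the commit message is built from a Hadamard matrix, and R1/R2 are orthogonal rotations. The sketch below is illustrative only; `sylvester_hadamard` is a hypothetical helper, not part of this commit. It constructs a Hadamard matrix via Sylvester's recursion and checks that, once scaled by 1/sqrt(n), it is orthogonal, which is the property the rotations rely on.

```python
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    """Return an n x n Hadamard matrix (n must be a power of two), built recursively."""
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of two"
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H

n = 64
R = sylvester_hadamard(n) / n ** 0.5  # scaling by 1/sqrt(n) makes the matrix orthogonal
assert torch.allclose(R @ R.T, torch.eye(n), atol=1e-5)
```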

File tree

8 files changed: +99838 -1 lines changed


test/prototype/test_spinquant.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
```python
import pytest
import torch

from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant


def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
    model = Transformer.from_name(name)
    model.to(device=device, dtype=precision)
    return model.eval()


_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_spinquant_no_quantization(device):
    model = _init_model(device=device)
    seq_len = 16
    batch_size = 1
    is_training = False
    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
    input_pos = None if is_training else torch.arange(seq_len).to(device)
    with torch.device(device):
        model.setup_caches(max_batch_size=batch_size, max_seq_length=seq_len, training=is_training)

    with torch.no_grad():
        out = model(input_ids, input_pos)
        apply_spinquant(model)
        out_spinquant = model(input_ids, input_pos)

    # Output should be the same without quantization (the rotations cancel out)
    # TODO: not sure if these atol/rtol are excessively large (it fails for smaller values)
    torch.testing.assert_close(out, out_spinquant, atol=5e-2, rtol=1e-2)


# TODO: test GPTQ compatibility?
```
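The test asserts that the outputs match because the inserted rotations are orthogonal: replacing a weight `W` with `W @ R` while feeding it `R.T @ x` instead of `x` leaves the product unchanged. A minimal standalone sketch of that identity (hypothetical variable names, with a random orthogonal matrix standing in for the actual R1/R2/R4 rotations):

```python
import torch

torch.manual_seed(0)
n = 128

# A random orthogonal matrix stands in for the SpinQuant rotation here.
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))

W = torch.randn(n, n, dtype=torch.float64)  # stand-in for a linear layer's weight
x = torch.randn(n, dtype=torch.float64)     # stand-in for an activation vector

out = W @ x
out_rotated = (W @ Q) @ (Q.T @ x)  # rotate the weight, counter-rotate the input

# Exact cancellation in float64; in a bfloat16 model it is only approximate,
# which is one reason the test above uses relatively loose atol/rtol.
print(torch.allclose(out, out_rotated))  # True
```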

torchao/_models/llama/eval.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -31,6 +31,7 @@
 from tokenizer import get_tokenizer
 import time
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.prototype.spinquant import apply_spinquant

 def run_evaluation(
     checkpoint_path: Path,
@@ -69,6 +70,8 @@ def run_evaluation(
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

     if quantization:
+        if "spinquant" in quantization:
+            apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -229,7 +232,7 @@ def run_evaluation(
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
            "int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
-           "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
+           "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
            "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
            "float8wo, float8dq, float8saq"
        ),
```
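With this hook in place, SpinQuant can be requested from the eval script through the existing quantization flag. A hypothetical invocation (the checkpoint path is a placeholder, and the flag spellings assume the script's usual `--checkpoint_path`/`--quantization` options):

```shell
# Hypothetical example; point --checkpoint_path at a real converted llama checkpoint.
python torchao/_models/llama/eval.py \
    --checkpoint_path /path/to/llama/model.pth \
    --quantization spinquant
```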

torchao/prototype/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@
 - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
 - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
 - [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
+- [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)

 #### Roadmap

```
torchao/prototype/spinquant/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# SpinQuant

Re-implementation of SpinQuant based on the official code implementation (https://github.com/facebookresearch/SpinQuant).

## Usage

Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:

```shell
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```
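For use outside the eval script, `apply_spinquant` can also be called directly on a model before it is quantized, mirroring the new test. A minimal sketch, assuming the llama `Transformer` from `torchao._models` as in the test above (the subsequent quantization step is omitted):

```python
import torch
from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant

# Sketch: build a llama model, apply the SpinQuant rotations in-place,
# then quantize it afterwards as usual.
model = Transformer.from_name("7B")
model.to(device="cpu", dtype=torch.bfloat16)
model = model.eval()
apply_spinquant(model)
```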
torchao/prototype/spinquant/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
```python
from .spinquant import apply_spinquant
```
