README.md (35 changes: 23 additions & 12 deletions)
@@ -54,27 +54,38 @@ We've added kv cache quantization and other features in order to enable long con

In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
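As a rough sketch of the int4 weight-only piece of that setup (the kv cache quantization settings themselves live in the llama scripts linked above and are not shown here; the tiny `model` below is only a stand-in):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# Stand-in model; int4 weight-only quantization targets the linear layers and,
# for the tinygemm kernel, expects bfloat16 weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# Swap linear weights for int4 weight-only quantized tensors (grouped per channel)
quantize_(model, int4_weight_only(group_size=128))

# Inference proceeds as usual
out = model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))
```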

## Training

### Quantization Aware Training

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).

```python
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# Run training... (not shown)

# Convert fake quantization to actual quantized operations
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

## Training

### Float8

[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
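A minimal sketch of that flow, assuming a recent CUDA GPU with float8 support (e.g. H100) and linear dimensions divisible by 16; see the float8 README linked above for the full recipes and configuration options:

```python
import torch
from torchao.float8 import convert_to_float8_training

# Toy model; float8 training swaps nn.Linear for Float8Linear
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 2048),
    torch.nn.ReLU(),
    torch.nn.Linear(2048, 2048),
).cuda().to(torch.bfloat16)

# Optionally skip layers via module_filter_fn (here: convert everything)
convert_to_float8_training(model, module_filter_fn=lambda mod, fqn: True)

# Train as usual (torch.compile is recommended for best speedups)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optimizer.step()
```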
test/quantization/test_qat.py (65 changes: 65 additions & 0 deletions)
@@ -25,6 +25,7 @@
from torchao.quantization.qat.api import (
    ComposableQATQuantizer,
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import (
@@ -42,6 +43,9 @@
    _GenericFakeQuantize,
    _get_qmin_qmax,
)
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    TorchAODType,
@@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self):
                lambda m, _: isinstance(m, torch.nn.ReLU),
            )

    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
    )
    def test_quantize_api_convert_path(self):
        """
        Test that the following:

            quantize_(model, intx_quantization_aware_training(...))
            quantize_(model, from_intx_quantization_aware_training(...))
            quantize_(model, int8_dynamic_activation_int4_weight())

        can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
        """
        from torchao.quantization.qat import (
            Int8DynActInt4WeightQATQuantizer,
        )

        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
        baseline_model = copy.deepcopy(m)

        # Baseline prepare
        baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        baseline_model = baseline_quantizer.prepare(baseline_model)

        # quantize_ prepare
        activation_config = FakeQuantizeConfig(
            torch.int8,
            "per_token",
            is_symmetric=False,
        )
        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
        quantize_(
            m,
            intx_quantization_aware_training(activation_config, weight_config),
        )

        # Compare prepared values
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        out = m(*x)
        baseline_out = baseline_model(*x2)
        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

        # Baseline convert
        baseline_model = baseline_quantizer.convert(baseline_model)

        # quantize_ convert
        quantize_(m, from_intx_quantization_aware_training())
        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

        # Compare converted values
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        out = m(*x)
        baseline_out = baseline_model(*x2)
        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)


if __name__ == "__main__":
    unittest.main()
torchao/quantization/qat/README.md (189 changes: 144 additions & 45 deletions)
@@ -19,12 +19,6 @@ x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
```
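As a concrete illustration of the formula above, a self-contained per-tensor int8 sketch (torchao's actual fake quantizers compute scales and zero points per token or per group, but the arithmetic is the same):

```python
import torch

# Per-tensor asymmetric int8 fake quantization, following the formula above
x_float = torch.randn(4, 8)
qmin, qmax = -128, 127

# Derive scale and zero point from the observed range
scale = (x_float.max() - x_float.min()) / (qmax - qmin)
zp = qmin - torch.round(x_float.min() / scale)

# Quantize to the integer grid, then dequantize back to float
x_fq = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax)
x_fq = (x_fq - zp) * scale  # still float, but limited to 256 representable values
```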

## API

torchao currently supports two QAT schemes for linear layers:
- int8 per token dynamic activations + int4 per group weights
- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)

QAT typically involves applying a transformation to your model before and after training.
In torchao, these are represented as the prepare and convert steps: (1) prepare inserts
fake quantize operations into linear layers, and (2) convert transforms the fake quantize
@@ -34,64 +34,169 @@ Between these two steps, training can proceed exactly as before.

![qat](images/qat_diagram.png)


## torchao APIs

torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_)
API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API
allows flexible configuration of quantization settings for both activations and weights,
while the Quantizer classes each hardcode a specific quantization setting.

For example, running QAT on a single GPU:

```python
import torch
from torchtune.models.llama3 import llama3

# Set up smaller version of llama3 to fit in a single GPU
def get_model():
    return llama3(
        vocab_size=4096,
        num_layers=16,
        num_heads=16,
        num_kv_heads=4,
        embed_dim=2048,
        max_seq_len=2048,
    ).cuda()

# Example training loop
def train_loop(m: torch.nn.Module):
    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(10):
        example = torch.randint(0, 4096, (2, 16)).cuda()
        target = torch.randn((2, 16, 4096)).cuda()
        output = m(example)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

### quantize_ API (recommended)

The recommended way to run QAT in torchao is through the `quantize_` API:
1. **Prepare:** specify how weights and/or activations are to be quantized through
[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and pass these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)

For example:


```python
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# train
train_loop(model)

# convert: transform fake quantization ops into actual quantized ops
# swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# inference or generate
```

To fake quantize embeddings in addition to linear layers, you can additionally call
the following with a filter function during the prepare step:

```
quantize_(
    m,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```
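Putting the two together, a sketch of a prepare step that fake quantizes both linear and embedding layers (reusing the `model` from the example above) might look like:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# Linear layers: fake quantize activations and weights
# (the default filter used by `quantize_` targets linear layers)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# Embedding layers: weight-only fake quantization, selected via `filter_fn`
quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```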


### Quantizer API (legacy)
> **Review comment (Contributor):** is there a deprecation plan for this API?
>
> **Reply (Contributor Author):** Not yet, but will come up with one


Alternatively, torchao provides a few hardcoded quantization settings through
the following Quantizers:
- [Int8DynActInt4WeightQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
- [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), targeting int4 per-group asymmetric weight (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
- [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94) (embedding), targeting int4 per-group symmetric weight

For example:
```python
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Quantizer for int8 dynamic per-token activations +
# int4 grouped per-channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`,
# which simulates quantization numerics during training
# without performing any dtype casting
model = qat_quantizer.prepare(model)

# train
train_loop(model)

# convert: transform fake quantization ops into actual quantized ops
# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`;
# the converted model has the exact same structure as the one produced
# by the corresponding PTQ flow through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)

# inference or generate
```

To use multiple Quantizers in the same model for different layer types,
users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242)
as follows:

```python
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

group_size = 32
quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
    Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
])

# prepare + train + convert as before
model = quantizer.prepare(model)
train_loop(model)
model = quantizer.convert(model)
```

## torchtune integration

torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
to allow users to run quantization-aware fine-tuning as follows:

```
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
```

torchtune also supports a [QAT + LoRA distributed training recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py)
that is 1.89x faster and uses 36.1% less memory compared to vanilla QAT in our early experiments.
You can read more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700):

```
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
```

For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).

## Evaluation Results

torchao/quantization/qat/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -1,6 +1,7 @@
from .api import (
    ComposableQATQuantizer,
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
from .embedding import (
@@ -18,4 +19,5 @@
"Int4WeightOnlyEmbeddingQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
]