# torchao: PyTorch Architecture Optimization

[Join our Discord](https://discord.gg/cudamode)

This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an [issue](https://github.com/pytorch/ao/issues).

## Introduction

`torchao` is a PyTorch library for quantization and sparsity.

## Get Started

### Installation

`torchao` makes liberal use of several new features in PyTorch, so it's recommended to use it with the current nightly or the latest stable version of PyTorch.

Stable Release
```Shell
pip install torchao
```

Nightly Release
```Shell
pip install torchao-nightly
```

From source
```Shell
git clone https://github.com/pytorch/ao
cd ao
python setup.py develop
```

### Quantization

```python
import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, input)

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
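
To check whether quantization actually paid off on your hardware, you can time the compiled model with `torch.utils.benchmark`. A minimal sketch, reusing `model` and `input` from the example above:

```python
from torch.utils import benchmark

# warm up so compilation and autotuning don't pollute the measurement
for _ in range(3):
    model(input)

# time the quantized, compiled model
timer = benchmark.Timer(stmt="model(input)", globals={"model": model, "input": input})
print(timer.blocked_autorange())
```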

### Sparsity

```python
import torch
from torch.sparse import to_sparse_semi_structured
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: finding a sparse (2:4) subnetwork by zeroing 2 weights in every block of 4
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1, 4),
                                  zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
```
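
Once the `Linear` weights are `SparseSemiStructuredTensor`s, inference dispatches to the 2:4 sparse kernels transparently. A quick sanity check, reusing the `model` built above (the 64x64 shapes are chosen to satisfy the sparse kernels' alignment constraints):

```python
# bfloat16 CUDA input matching the model above
x = torch.randn(64, 64, dtype=torch.bfloat16, device='cuda')

with torch.inference_mode():
    out = model(x)  # runs through the semi-structured sparse matmul

print(out.shape)  # torch.Size([64, 64])
```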

To learn more, check out the API examples in
* [quantization](./torchao/quantization)
* [sparsity](./torchao/sparsity)
* [dtypes](./torchao/dtypes)

## Supported Features
1. [Quantization algorithms](./torchao/quantization), with a usage sketch after this list
   - [Int8 weight-only](https://github.com/pytorch/ao/blob/main/torchao/quantization/weight_only.py) quantization
   - [Int4 weight-only](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/int4mm.cu) quantization
   - [GPTQ](https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py) and [Smoothquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/smoothquant.py) for low latency inference
   - High level [torchao.autoquant API](https://github.com/pytorch/ao/blob/main/torchao/quantization/autoquant.py) and [kernel autotuner](https://github.com/pytorch/ao/blob/main/torchao/kernel/autotuner.py) targeting SOTA performance across varying model shapes on consumer and enterprise GPUs
2. [Sparsity algorithms](./torchao/sparsity) such as Wanda that help improve the accuracy of sparse networks
3. Support for lower precision [dtypes](./torchao/dtypes) such as
   - [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py), which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code
   - [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py)
4. [Bleeding edge kernels](./torchao/prototype/), experimental kernels with no backwards compatibility guarantees
   - [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
   - [fused HQQ Gemm kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
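
If you'd rather pick a quantization technique yourself than rely on `autoquant`, the quantization APIs can also be applied directly. Below is a minimal sketch of int8 weight-only quantization; it assumes a `change_linear_weights_to_int8_woqtensors` helper in `torchao.quantization.quant_api`, so check the [quantization README](./torchao/quantization) for the exact names in your version:

```python
import torch
from torchao.quantization import quant_api

# toy bfloat16 model on GPU
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# swap each Linear's weight for an int8 weight-only quantized tensor subclass
# (helper name assumed; see torchao/quantization for the current API)
quant_api.change_linear_weights_to_int8_woqtensors(model)

# compile so dequantization fuses into the surrounding kernels
model = torch.compile(model, mode='max-autotune')
model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))
```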

## Our Goals

* Composability with `torch.compile`: We rely heavily on `torch.compile` to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing custom CUDA and Triton kernels where needed.
* Composability with `FSDP`: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
* Performance: We measure our performance on every commit using an A10G. We also regularly run performance benchmarks on the [torchbench](https://github.com/pytorch/benchmark) suite.
* Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU based servers (with `torch.compile`) and on mobile backends (with ExecuTorch).
* Packaging kernels should be easy: We support custom [CUDA and Triton extensions](./torchao/csrc/) so you can focus on writing your kernels, and we'll ensure they work on most operating systems and devices.

## Integrations

torchao has been integrated with other libraries, including:

* [torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md) leverages our 8- and 4-bit weight-only quantization techniques, with optional support for GPTQ
* [ExecuTorch](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#quantization) leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization
* [HQQ](https://github.com/mobiusml/hqq/blob/master/hqq/backends/torchao.py) leverages our int4mm kernel for low latency inference

## Success stories

Our kernels have been used to achieve SOTA inference performance on

* Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
* Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

## License