|
31 | 31 | from tokenizer import get_tokenizer
|
32 | 32 | import time
|
33 | 33 | from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
|
| 34 | +from torchao.prototype.spinquant import apply_spinquant |
34 | 35 |
|
35 | 36 | def run_evaluation(
|
36 | 37 | checkpoint_path: Path,
|
@@ -69,6 +70,8 @@ def run_evaluation(
|
69 | 70 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
|
70 | 71 |
|
71 | 72 | if quantization:
|
| 73 | + if "spinquant" in quantization: |
| 74 | + apply_spinquant(model) |
72 | 75 | if "int8wo" in quantization:
|
73 | 76 | quantize_(model, int8_weight_only())
|
74 | 77 | if "int8dq" in quantization:
|
@@ -229,7 +232,7 @@ def run_evaluation(
|
229 | 232 | help=(
|
230 | 233 | "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
|
231 | 234 | "int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
|
232 | | - "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, " |
| 235 | + "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, " |
233 | 236 | "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
|
234 | 237 | "float8wo, float8dq, float8saq"
|
235 | 238 | ),
|
|
0 commit comments