Skip to content

Commit 6e00835

Browse files
committed
Review updates
1 parent 3b3947c commit 6e00835

File tree

4 files changed

+20
-20
lines changed

4 files changed

+20
-20
lines changed

benchmarks/microbenchmarks/benchmark_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
4545

4646
# Use quantize_ to apply each quantization function to the model
4747
m_copy = deepcopy(base_model).eval().to(config.device)
48-
aoBaseConfig = string_to_config(
48+
ao_base_config = string_to_config(
4949
config.quantization,
5050
config.sparsity,
5151
high_precision_dtype=config.high_precision_dtype,
@@ -59,7 +59,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
5959
):
6060
if is_cuda:
6161
print(f"Applying {config.sparsity} sparsity to model")
62-
sparsify_(m_copy, aoBaseConfig)
62+
sparsify_(m_copy, ao_base_config)
6363
else:
6464
print(
6565
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
@@ -70,7 +70,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
7070
pass # No quantization or sparsity specified, do nothing
7171
else:
7272
print("Quantizing model....")
73-
quantize_(m_copy, aoBaseConfig)
73+
quantize_(m_copy, ao_base_config)
7474

7575
if config.use_torch_compile:
7676
print("Compiling model....")

benchmarks/microbenchmarks/benchmark_runner.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,6 @@ def get_quantization_sparsity_recipes(
8484
"""
8585
config_recipes = set()
8686

87-
# Handle edge cases
88-
if sparsity_recipes is None and quantization_recipes is None:
89-
return {("baseline", None)}
90-
if sparsity_recipes is None:
91-
return {(quant, None) for quant in quantization_recipes}
92-
if quantization_recipes is None:
93-
return {("baseline", sparse) for sparse in sparsity_recipes}
94-
9587
# Always include baseline without sparsity
9688
config_recipes.add(("baseline", None))
9789

@@ -134,8 +126,8 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig
134126
# Create all possible combinations
135127
configs = []
136128
quantization_sparsity_recipes = get_quantization_sparsity_recipes(
137-
config.get("quantization_config_recipe_names", None),
138-
config.get("sparsity_config_recipe_names", None),
129+
config.get("quantization_config_recipe_names", []),
130+
config.get("sparsity_config_recipe_names", []),
139131
)
140132
for model_param in config["model_params"]:
141133
shapes, params = get_param_combinations(model_param)

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Sample configuration for inference benchmarks
22
benchmark_mode: "inference"
33
quantization_config_recipe_names:
4-
# - "baseline" Will always run a baseline instance
4+
# Will run a baseline inference for model by default, without quantization for comparison
55
- "int4wo-32"
66
- "marlin"
77
sparsity_config_recipe_names:
8-
# - "none" Will always run an instance without sparsity
8+
# Will run a baseline inference for model by default, without sparsity for comparison
99
- "semi-sparse"
1010
- "block"
1111
output_dir: "benchmarks/microbenchmarks/results"

benchmarks/microbenchmarks/test/test_benchmark_inference.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ def test_run_inference(self, mock_string_to_config):
4949
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
5050

5151
@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
52-
def test_run_inference_with_sparsity(self, mock_string_to_config):
52+
def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
5353
"""Test running inference with sparsity configurations"""
5454
# Mock string_to_config to return valid configs
55+
from torchao.dtypes import MarlinSparseLayout
5556
from torchao.quantization import Int4WeightOnlyConfig
56-
from torchao.sparsity.sparse_api import (
57-
BlockSparseWeightConfig,
58-
)
5957

6058
# Test with semi-sparse config
61-
mock_string_to_config.return_value = Int4WeightOnlyConfig()
59+
mock_string_to_config.return_value = Int4WeightOnlyConfig(
60+
layout=MarlinSparseLayout()
61+
)
6262
config = BenchmarkConfig(
6363
quantization="marlin",
6464
sparsity="semi-sparse",
@@ -77,6 +77,14 @@ def test_run_inference_with_sparsity(self, mock_string_to_config):
7777
self.assertIsInstance(result, BenchmarkResult)
7878
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
7979

80+
@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
81+
def test_run_inference_with_block_sparsity(self, mock_string_to_config):
82+
"""Test running inference with sparsity configurations"""
83+
# Mock string_to_config to return valid configs
84+
from torchao.sparsity.sparse_api import (
85+
BlockSparseWeightConfig,
86+
)
87+
8088
# Test with block sparsity
8189
mock_string_to_config.return_value = BlockSparseWeightConfig()
8290
config = BenchmarkConfig(

0 commit comments

Comments
 (0)