Skip to content

Commit 1958a97

Browse files
committed
layout_tensor -> _layout
1 parent f0539d2 commit 1958a97

File tree

24 files changed: 258 additions (+258) and 258 deletions (-258).

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
3131
base_functions.append(int4_weight_only(group_size=32))
3232

3333
if do_sparse:
34-
base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayout()))
34+
base_functions.append(int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()))
3535

3636
if is_cuda_8_9:
3737
base_functions.append(float8_weight_only())

test/dtypes/test_floatx.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,8 +81,8 @@ def test_to_copy_device(self, ebits, mbits):
8181
x = torch.randn(256, 64)
8282
scale = choose_qparams_affine_floatx(x, ebits, mbits)
8383
x = quantize_affine_floatx(x, scale, ebits, mbits)
84-
layout_type = FloatxTensorCoreLayout(ebits, mbits)
85-
floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
84+
_layout = FloatxTensorCoreLayout(ebits, mbits)
85+
floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda()
8686
assert floatx_tensor_impl.device.type == "cuda"
8787
floatx_tensor_impl = floatx_tensor_impl.cpu()
8888
assert floatx_tensor_impl.device.type == "cpu"

test/integration/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
876876
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
877877
for groupsize in [64, 32]:
878878
for inner_k_tiles in [4, 2]:
879-
kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
879+
kwargs = {"groupsize": groupsize, "_layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
880880

881881
def api(mod):
882882
kwargs_copy = kwargs.copy()
@@ -888,7 +888,7 @@ def api(mod):
888888
unwrap_tensor_subclass(mod)
889889
else:
890890
kwargs_copy["inner_k_tiles"] = inner_k_tiles
891-
del kwargs_copy["layout_type"]
891+
del kwargs_copy["_layout"]
892892
change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
893893

894894
self._test_lin_weight_subclass_api_impl(

test/sparsity/test_marlin.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
5050
dense_result = model_copy(self.input.bfloat16()).half()
5151

5252
# Sparse + quantized
53-
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayout()))
53+
quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
5454
sparse_result = self.model(self.input)
5555

5656
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
6767
dense_result = model_copy(self.input.bfloat16()).half()
6868

6969
# Sparse + quantized
70-
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayout()))
70+
quantize_(self.model, int4_weight_only(_layout=MarlinSparseLayout()))
7171
self.model.forward = torch.compile(self.model.forward, fullgraph=True)
7272
sparse_result = self.model(self.input)
7373

test/sparsity/test_sparse_api.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):
7474

7575
quantize_(
7676
model,
77-
int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayout()),
77+
int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
7878
)
7979
if compile:
8080
model = torch.compile(model)
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
108108
dense_result = model_copy(input.bfloat16()).half()
109109

110110
# Sparse + quantized
111-
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayout()))
111+
quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
112112
if compile:
113113
model = torch.compile(model)
114114
sparse_result = model(input)
@@ -190,7 +190,7 @@ def test_sparse(self, compile):
190190
quantize_(
191191
model,
192192
int8_dynamic_activation_int8_weight(
193-
layout_type=BlockSparseLayout(blocksize=64)
193+
_layout=BlockSparseLayout(blocksize=64)
194194
),
195195
)
196196
if compile:

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def run_evaluation(
9898
quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
9999
if "marlin" in quantization:
100100
from torchao.dtypes import MarlinSparseLayout
101-
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayout()))
101+
quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
102102
if "int4wo" in quantization and "gptq" in quantization:
103103
# avoid circular imports
104104
from torchao._models._eval import InputRecorder

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ def main(
231231
quantize_(model, int4_weight_only(group_size=groupsize))
232232
if "marlin" in quantization:
233233
from torchao.dtypes import MarlinSparseLayout
234-
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayout()))
234+
quantize_(model, int4_weight_only(_layout=MarlinSparseLayout()))
235235
if "fp6" in quantization:
236236
quantize_(model, fpx_weight_only(3, 2))
237237
if quantization.startswith("awq"):

torchao/_models/sam/eval_combo.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -315,7 +315,7 @@ def mlp_only(mod, name):
315315
int8_dynamic_activation_int8_weight(),
316316
attn_only)
317317
quantize_(predictor.model.image_encoder,
318-
int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayout()),
318+
int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout()),
319319
mlp_lin1_only)
320320
sparsify_(predictor.model.image_encoder,
321321
semi_sparse_weight(),
@@ -330,7 +330,7 @@ def mlp_only(mod, name):
330330
quantize_(predictor.model.image_encoder,
331331
int8_dynamic_activation_int8_weight(),
332332
attn_only)
333-
quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayout()), mlp_lin1_only)
333+
quantize_(predictor.model.image_encoder, int4_weight_only(_layout=MarlinSparseLayout()), mlp_lin1_only)
334334
sparsify_(predictor.model.image_encoder,
335335
semi_sparse_weight(),
336336
mlp_lin2_only)

0 commit comments

Comments
 (0)