117 changes: 117 additions & 0 deletions test/sparsity/test_marlin.py
@@ -0,0 +1,117 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.marlin import (
pack_to_marlin_24,
unpack_from_marlin_24,
inject_24
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
ZeroPointDomain,
MappingType,
)


class SparseMarlin24(TestCase):

def setUp(self):
super().setUp()
torch.manual_seed(0)

self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
self.model = (
nn.Sequential(
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
nn.ReLU(),
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
)
.half()
.cuda()
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_eager(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(self.model)

self.model.forward = torch.compile(self.model.forward, fullgraph=True)
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_pack_unpack_equivalence(self):
num_bits = 4
group_size = 128
shape = (11008, 4096)
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
zero_point_dtype = torch.bfloat16
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
scale_dtype = None

w = torch.rand(shape, dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity mask
w_24, _ = inject_24(w, *w.shape)

# Quantize weights
scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
w_q_24, scales, num_bits, group_size
)
unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
q_w_comp, packed_scales, meta, shape, group_size, num_bits
)

assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"


if __name__ == "__main__":
run_tests()
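For context, "2:4" (two-out-of-four) sparsity means at most two non-zero values in every contiguous group of four weights along the reduction dimension; that is the pattern `apply_fake_sparsity` and `inject_24` enforce before quantization in the tests above. A minimal sketch of the pruning rule itself, independent of torchao (`prune_2_4` is a hypothetical stand-in, not the library's implementation):

```python
import torch

def prune_2_4(w: torch.Tensor) -> torch.Tensor:
    """Zero the 2 smallest-magnitude entries in every group of 4 weights."""
    groups = w.reshape(-1, 4)
    # Indices of the two smallest-magnitude entries per group of four
    _, idx = torch.topk(groups.abs(), k=2, dim=1, largest=False)
    mask = torch.ones_like(groups)
    mask.scatter_(1, idx, 0.0)
    return (groups * mask).reshape(w.shape)

w = torch.randn(8, 16)
w_24 = prune_2_4(w)
# Every group of four now holds at most two non-zeros
assert (w_24.reshape(-1, 4) != 0).sum(dim=1).le(2).all()
```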
24 changes: 15 additions & 9 deletions test/test_ops.py
@@ -304,6 +304,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
)


MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
MNK_FACTORS = [
@@ -318,8 +319,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
@@ -374,15 +375,15 @@ def reshape_w(w):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity
@@ -391,19 +392,24 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
# Symmetric quantize
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

# The kernel consumes a 2D activation; flatten any batch dims
input_2d = a_input.view(-1, a_input.shape[-1])
a_input_in, a_input_out = input_2d.shape  # rows (size_m) and cols (size_k) of the 2D input

# Obtain reference output
output_ref = torch.matmul(a_input, w_24_ref)
output_ref = torch.matmul(input_2d, w_24_ref)
output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],))

# Packs to marlin 2:4
marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
torch.cuda.synchronize()
output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],))

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
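The reshapes above exist because the Marlin kernel consumes a 2D activation: a batched `(B, M, K)` input is flattened to `(B*M, K)` before the GEMM and the output is viewed back to `(B, M, N)` afterwards. The shape logic in isolation, with no Marlin dependency:

```python
import torch

B, M, K, N = 4, 16, 128, 512
a = torch.randn(B, M, K)
w = torch.randn(K, N)

out_2d = a.view(-1, K) @ w                 # (B*M, N)
out = out_2d.reshape(a.shape[:-1] + (N,))  # back to (B, M, N)
assert out.shape == (B, M, N)
```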
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
}

} // namespace torchao
} // namespace torchao
2 changes: 1 addition & 1 deletion torchao/csrc/sparse_marlin.cpp
@@ -5,4 +5,4 @@
TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor");
}
}
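A sketch of driving this schema from Python, stitched together from the two test files in this PR (quantization as in `test_pack_unpack_equivalence`, argument order as in `fn_inputs` from `test_ops.py`). Loud assumptions: the import path of `marlin_24_workspace` is not visible in this diff and is a guess, and a square weight is used to sidestep layout-orientation details; treat this as illustrative, not canonical.

```python
import torch
import torchao
from torchao.quantization.quant_primitives import (
    MappingType, ZeroPointDomain, choose_qparams_affine, quantize_affine,
)
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24
from torchao.sparsity.marlin import marlin_24_workspace  # assumed import path

size_m, size_k, size_n, num_bits, group_size = 16, 4096, 4096, 4, 128
x = torch.randn(size_m, size_k, dtype=torch.float16, device="cuda")
w = torch.rand(size_k, size_n, dtype=torch.float16, device="cuda")
w_24, _ = inject_24(w, size_k, size_n)  # enforce the 2:4 pattern

# int4 symmetric quantization, mirroring test_pack_unpack_equivalence
block_size = (1, group_size)
scales, zeros = choose_qparams_affine(
    w_24, MappingType.SYMMETRIC, block_size, torch.int32, 0, 15, 1e-6,
    None, torch.bfloat16, True, ZeroPointDomain.INT,
)
w_q_24 = quantize_affine(
    w_24, block_size, scales, zeros, torch.int32, 0, 15, ZeroPointDomain.INT,
)
scales = scales.reshape(-1, w_q_24.shape[1])

# Pack for the sparse Marlin kernel and run the GEMM
q_w_comp, packed_scales, meta = pack_to_marlin_24(w_q_24, scales, num_bits, group_size)
workspace = marlin_24_workspace(size_n)
out = torchao.ops.marlin_24_gemm(
    x, q_w_comp, meta, packed_scales, workspace,
    num_bits, size_m, size_n, size_k,
)
```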
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -15,6 +15,7 @@
TensorCoreTiledLayoutType,
Float8LayoutType,
Float8AQTLayout,
MarlinSparseLayoutType,
)

__all__ = [
@@ -33,4 +34,5 @@
"TensorCoreTiledLayoutType",
"Float8LayoutType",
"Float8AQTLayout",
"MarlinSparseLayoutType",
]
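Taken together, the export above enables the end-user flow exercised in `test_quant_sparse_marlin_layout_eager`: prune to 2:4, then apply int4 weight-only quantization with the Marlin sparse layout. A condensed version of that flow, using only names introduced in this PR's test files:

```python
import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

model = nn.Sequential(nn.Linear(4096, 4096)).half().cuda()
apply_fake_sparsity(model)  # prune weights to the 2:4 pattern
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = model(x)  # dispatches to the sparse Marlin int4 kernel
```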