diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 75701c55ca..730969ba9c 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2571,6 +2571,124 @@ def forward(self, x):
                 node_list,
             )
 
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                            weight: weight_qspec,
+                            bias: bias_qspec,
+                        },
+                        _annotated=True,
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "conv_type": torch.nn.Conv1d,
+                "bn_type": torch.nn.BatchNorm1d,
+                "example_input": (torch.randn(1, 3, 5),),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv2d,
+                "bn_type": torch.nn.BatchNorm2d,
+                "example_input": (torch.randn(1, 3, 5, 5),),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "conv_type": torch.nn.Conv3d,
+                "bn_type": torch.nn.BatchNorm3d,
+                "example_input": (torch.randn(1, 3, 5, 5, 5),),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(conv_type=test_case["conv_type"].__name__):
+
+                class M(torch.nn.Module):
+                    def __init__(self):
+                        super().__init__()
+                        self.conv = test_case["conv_type"](3, 3, 3, padding="same")
+                        self.bn = test_case["bn_type"](3)
+
+                    def forward(self, x):
+                        return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                model = M().eval()
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index ad5c0ae179..dc5f802fb8 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -625,8 +625,11 @@ def _is_conv_node(n: Node):
     """
     Return whether the node refers to an aten conv op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv1d.padding,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
     ]
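---

Context for reviewers (not part of the patch): _is_conv_node needs the extra
overloads because string padding (padding="same" or padding="valid") traces a
convolution to the aten.convNd.padding overload rather than
aten.convNd.default. A minimal sketch of that behavior, assuming a PyTorch
build where torch.export is available; the exact op set in the exported graph
can vary by release:

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # String padding selects the aten.conv2d.padding overload at trace time.
            self.conv = torch.nn.Conv2d(3, 3, 3, padding="same")

        def forward(self, x):
            return self.conv(x)

    ep = torch.export.export(M().eval(), (torch.randn(1, 3, 5, 5),))
    targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
    print(torch.ops.aten.conv2d.padding in targets)  # expected: True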