# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
import functools
import itertools
import math
import operator
from typing import Any

import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg

from .. import config
from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import (
    Arg,
    CallFunction,
    filter_nodes,
    KeywordArg,
    ListOf,
    Match,
    stable_topological_sort,
)
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern


aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
quantized = torch.ops.quantized

# Only for per-tensor quant, since permute may change the channel index
_PER_TENSOR_QUANTIZE_OPS = [
    quantized_decomposed.quantize_per_tensor.default,
    quantized_decomposed.quantize_per_tensor.tensor,
]

_VIEW_OPS = [
    aten.transpose.int,
    aten.permute.default,
    aten.view.default,
]

"""
The quantization.py file primarily contains the passes related to quantization
fusion in Inductor, including:
1. Dequant promotion;
2. Conv/GEMM weight prepack with the oneDNN library;
3. Conv/GEMM quantization fusion with the output quant node (if any);
4. Quantization fusion of other pointwise operators, such as qmaxpool2d, qcat, and more.

It covers both int8-mixed-fp32 and int8-mixed-bf16 quantization. Compared with
int8-mixed-fp32, the main differences in the int8-mixed-bf16 patterns are:
1. There is a to(dtype=torch.bfloat16) node at the activation and weight inputs of Conv/GEMM.
2. There is a to(dtype=torch.float32) node at the output of Conv/GEMM before it feeds the next quant node.
Refer to https://github.com/pytorch/pytorch/issues/111640 for the detailed design of
int8-mixed-bf16 quantization.
"""


def _get_pattern_output_dtype(match: Match):
    """
    Get the pattern's output dtype from the output node's meta.
    Assume there is only 1 output node in this matched pattern.
""" pattern_output_nodes = match.output_nodes() assert len(pattern_output_nodes) == 1 output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype assert output_dtype in [ torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.float8_e4m3fn, ] return output_dtype def _may_generate_pattern_with_dtype_convert( pattern, dtype=Arg(), with_dtype_convert=True, users=1 ): if with_dtype_convert: return CallFunction( prims.convert_element_type.default, pattern, dtype, _users=users, ) else: return pattern def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): if with_reshape: return CallFunction( torch.ops.aten.reshape.default, pattern, reshape_size, ) else: return pattern def _generate_linear_t_pattern( _dequant_per_channel_pattern, dtype, ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = CallFunction( aten.permute.default, _may_generate_pattern_with_dtype_convert( _dequant_per_channel_pattern, KeywordArg("autocast_wgt_dtype"), dtype == torch.bfloat16, ), KeywordArg("permute_axes"), ) return t_pattern def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): # only insert to_dtype if is_bf16 is True computation_call = _may_generate_pattern_with_dtype_convert( call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users ) return unary_fusion(computation_call) def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): dequantize_per_tensor_activation_pattern = CallFunction( quantized_decomposed.dequantize_per_tensor.tensor if is_tensor_overload else quantized_decomposed.dequantize_per_tensor.default, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("x_quant_min"), KeywordArg("x_quant_max"), KeywordArg("x_dq_dtype"), ) return dequantize_per_tensor_activation_pattern dequantize_per_channel_weight_pattern = CallFunction( quantized_decomposed.dequantize_per_channel.default, KeywordArg("q_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), KeywordArg("w_axis"), KeywordArg("w_quant_min"), KeywordArg("w_quant_max"), KeywordArg("w_dtype"), ) dequantize_per_channel_to_bf16_weight_pattern = ( _may_generate_pattern_with_dtype_convert( dequantize_per_channel_weight_pattern, KeywordArg("autocast_wgt_dtype"), ) ) dequantize_per_channel_clone_weight_pattern = CallFunction( aten.clone.default, dequantize_per_channel_weight_pattern, memory_format=KeywordArg("memory_format"), ) dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( aten.clone.default, dequantize_per_channel_to_bf16_weight_pattern, memory_format=KeywordArg("memory_format"), ) def get_qconv_pt2e_pattern(users=1): return CallFunction( torch.ops.onednn.qconv_pointwise.default, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("packed_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), KeywordArg("dilation"), KeywordArg("groups"), KeywordArg("output_scale"), KeywordArg("output_zero_point"), KeywordArg("output_dtype"), KeywordArg("postop_name"), KeywordArg("postop_args"), KeywordArg("postop_algorithm"), _users=users, ) def get_qconv2d_binary_pt2e_pattern(users=1): return CallFunction( torch.ops.onednn.qconv2d_pointwise.binary, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("packed_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), KeywordArg("accum"), KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), KeywordArg("dilation"), 
KeywordArg("groups"), KeywordArg("output_scale"), KeywordArg("output_zero_point"), KeywordArg("output_dtype"), KeywordArg("accum_scale"), KeywordArg("accum_zero_point"), KeywordArg("binary_op_name"), KeywordArg("alpha"), KeywordArg("unary_op_name"), KeywordArg("unary_op_args"), KeywordArg("unary_op_algorithm"), _users=users, ) def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): qlinear_op = ( torch.ops.onednn.qlinear_pointwise.tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.default ) return CallFunction( qlinear_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("packed_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), KeywordArg("b"), KeywordArg("output_scale"), KeywordArg("output_zero_point"), KeywordArg("output_dtype"), KeywordArg("postop_name"), KeywordArg("postop_args"), KeywordArg("postop_algorithm"), _users=users, ) def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1): qlinear_op = ( torch.ops.onednn.qlinear_pointwise.binary_tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.binary ) return CallFunction( qlinear_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("packed_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), KeywordArg("x_2"), KeywordArg("b"), KeywordArg("output_scale"), KeywordArg("output_zero_point"), KeywordArg("output_dtype"), KeywordArg("x2_scale"), KeywordArg("x2_zp"), KeywordArg("binary_op_name"), KeywordArg("alpha"), KeywordArg("unary_op_name"), KeywordArg("unary_op_args"), KeywordArg("unary_op_algorithm"), _users=users, ) dequantize_accum_pattern = CallFunction( quantized_decomposed.dequantize_per_tensor.default, KeywordArg("accum"), KeywordArg("accum_scale"), KeywordArg("accum_zp"), Arg(), Arg(), KeywordArg("accum_dq_dtype"), ) def generate_pattern_with_binary( binary_post_op, computation_call, extra_input_pattern, dtype_convert=False, swap_inputs=False, ): binary_pattern = ( CallFunction( binary_post_op, extra_input_pattern, computation_call, ) if swap_inputs else CallFunction( binary_post_op, computation_call, extra_input_pattern, ) ) return _may_generate_pattern_with_dtype_convert( binary_pattern, KeywordArg("convert_dtype_after_inplace_add"), dtype_convert, ) def generate_pattern_with_unary(computation_call, unary_post_op): if unary_post_op is not None: return CallFunction( unary_post_op, computation_call, ) return computation_call def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): quantized_op_output_pattern_pt2e = CallFunction( quantized_decomposed.quantize_per_tensor.default, _may_generate_pattern_with_dtype_convert( computation_call, Arg(), with_dtype_convert, ), KeywordArg("o_inv_scale"), KeywordArg("o_zp"), KeywordArg("o_qmin"), KeywordArg("o_qmax"), KeywordArg("o_dtype"), ) return quantized_op_output_pattern_pt2e def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): if kwarg_name in check_node.kwargs: actual_value = check_node.kwargs[kwarg_name] return actual_value == expected_value else: assert len(check_node.args) >= (args_index + 1) actual_value = check_node.args[args_index] return actual_value == expected_value def _is_valid_quantized_conv_optimization_pattern(): def fn(match): output_dtype = _get_pattern_output_dtype(match) if output_dtype in [torch.float32, torch.bfloat16]: # Only keep matched pattern with same output_dtype qconv_node_after_weight_prepack = filter_nodes( match.nodes, torch.ops.onednn.qconv_pointwise )[0] return _check_node_kwarg_arg_value( 
qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype ) return True return fn def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): return ( _is_valid_qconv_binary_optimization_pattern() if has_binary_post_op else _is_valid_quantized_conv_optimization_pattern() ) def _is_valid_qconv_lowering_pattern(): def fn(match): if len(match.nodes) != 1: return False return match.nodes[0].target in ( torch.ops.onednn.qconv_pointwise.default, torch.ops.onednn.qconv_pointwise.tensor, torch.ops.onednn.qconv2d_pointwise.binary, torch.ops.onednn.qconv2d_pointwise.binary_tensor, ) return fn def _register_quantized_conv_lowering( pattern, pass_number, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_valid_qconv_lowering_pattern(), pass_number=pass_number, ) def qconv(match: Match, *args, **kwargs): # Activation QParams x, x_scale, x_zp = ( kwargs["x"], kwargs["x_scale"], kwargs["x_zp"], ) # Weight QParams packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # Conv Params b, stride, padding, dilation, groups = ( kwargs["b"], kwargs["stride"], kwargs["padding"], kwargs["dilation"], kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] # Output QParams o_inv_scale = kwargs["output_scale"] o_zero_point = kwargs["output_zero_point"] output_dtype = kwargs["output_dtype"] # post op postop_name = kwargs["postop_name"] postop_args = kwargs["postop_args"] postop_algorithm = kwargs["postop_algorithm"] computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, b, stride, padding, dilation, groups, o_inv_scale, o_zero_point, output_dtype, postop_name, postop_args, postop_algorithm, ) counters["inductor"]["qconv_unary_lower_count"] += 1 counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qconv def _is_valid_quantized_linear_optimization_pattern(): def fn(match): output_dtype = _get_pattern_output_dtype(match) if output_dtype in [torch.float32, torch.bfloat16]: # Only keep matched pattern with same output_dtype qlinear_node_after_weight_prepack = filter_nodes( match.nodes, torch.ops.onednn.qlinear_pointwise )[0] return _check_node_kwarg_arg_value( qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype ) return True return fn def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): return ( _is_valid_qlinear_binary_optimization_pattern() if has_binary_post_op else _is_valid_quantized_linear_optimization_pattern() ) def _is_valid_qlinear_lowering_pattern(): def fn(match): if len(match.nodes) != 1: return False return match.nodes[0].target in ( torch.ops.onednn.qlinear_pointwise.default, torch.ops.onednn.qlinear_pointwise.tensor, torch.ops.onednn.qlinear_pointwise.binary, torch.ops.onednn.qlinear_pointwise.binary_tensor, ) return fn def _register_quantized_linear_unary_lowering( pattern, pass_number, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_valid_qlinear_lowering_pattern(), pass_number=pass_number, ) def qlinear(match: Match, *args, **kwargs): output_dtype = _get_pattern_output_dtype(match) # Activation QParams x, x_scale, x_zp = ( kwargs["x"], kwargs["x_scale"], kwargs["x_zp"], ) # Weight QParams packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # bias b = kwargs["b"] if "b" in kwargs else None # Output QParams o_inv_scale = kwargs["output_scale"] o_zero_point = 
kwargs["output_zero_point"] # post op postop_name = kwargs["postop_name"] postop_args = kwargs["postop_args"] postop_algorithm = kwargs["postop_algorithm"] computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, b, o_inv_scale, o_zero_point, output_dtype, postop_name, postop_args, postop_algorithm, ) counters["inductor"]["qlinear_unary_lower_count"] += 1 counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qlinear def _register_quantized_linear_binary_lowering( pattern, pass_number, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_valid_qlinear_lowering_pattern(), pass_number=pass_number, ) def qlinear_binary(match: Match, *args, **kwargs): output_dtype = _get_pattern_output_dtype(match) assert output_dtype is not None # Activation QParams x, x_scale, x_zp = ( kwargs["x"], kwargs["x_scale"], kwargs["x_zp"], ) x2 = kwargs["x_2"] x2_scale = kwargs["x2_scale"] x2_zp = kwargs["x2_zp"] # Weight QParams packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # bias b = kwargs["b"] if "b" in kwargs else None # Output QParams o_inv_scale = kwargs["output_scale"] o_zero_point = kwargs["output_zero_point"] x2.realize() from .mkldnn_fusion import _can_be_inplace binary_op_name = kwargs["binary_op_name"] alpha = kwargs["alpha"] unary_op_name = kwargs["unary_op_name"] unary_op_args = kwargs["unary_op_args"] unary_op_algorithm = kwargs["unary_op_algorithm"] if binary_op_name == "sum" and not _can_be_inplace(x2): # When we enable the GEMM Template, the output of QLinear # will be reshaped from 2D back to 3D if the input is 3D. # This causes _can_be_inplace(x2) to return False if x2 happens # to be the output of QLinear in this scenario. # Change the post op from sum to binary add for this case. # Refer to test case: # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 binary_op_name = "add" computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, b, o_inv_scale, o_zero_point, output_dtype, x2_scale, x2_zp, binary_op_name, alpha, unary_op_name, unary_op_args, unary_op_algorithm, ) counters["inductor"]["qlinear_binary_lower_count"] += 1 counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qlinear_binary def _is_valid_qconv_binary_optimization_pattern(): return _is_valid_quantized_op_binary_optimization_pattern( torch.ops.onednn.qconv_pointwise ) def _is_valid_qlinear_binary_optimization_pattern(): return _is_valid_quantized_op_binary_optimization_pattern( torch.ops.onednn.qlinear_pointwise, # we don't insert q-dq for extra input due to accuracy issues extra_input_from_dequant=False, ) def _is_valid_quantized_op_binary_optimization_pattern( qop, extra_input_from_dequant=True ): # Check if it's a valid Binary Pattern for qconv2d and qlinear: # * qop_pointwise should only has one users # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern # * the two inputs of binary node should have attribute "meta" and should be tensors # * the two inputs of binary node should have the same shape # * All users of the extra input in this pattern should be # ancestor nodes of the compute node, except for the binary node # connected to the compute node. 
def fn(match): output_dtype = _get_pattern_output_dtype(match) compute_node = filter_nodes(match.nodes, qop)[0] # qop_pointwise should only have one user if len(compute_node.users) != 1: return False binary_node_inputs = next(iter(compute_node.users)).args assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" if output_dtype in [torch.float32, torch.bfloat16]: extra_input_of_binary_node = None for arg in binary_node_inputs: if arg != compute_node: extra_input_of_binary_node = arg break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern if extra_input_from_dequant and ( (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( extra_input_of_binary_node.target != quantized_decomposed.dequantize_per_tensor.default ) ): return False # the two inputs of binary node should have attribute "meta" and should be tensors if not ( hasattr(binary_node_inputs[0], "meta") and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] ) or not ( hasattr(binary_node_inputs[1], "meta") and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] ): return False # the two inputs of binary node should have the same shape if ( binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] ): return False # All users of the extra input in this pattern should be # ancestor nodes of the compute node, except for the binary node # connected to the compute node. from .mkldnn_fusion import _get_remaining_users extra_input_of_pattern = ( match.kwargs["other"] if "other" in match.kwargs else ( match.kwargs["accum"] if (output_dtype in [torch.uint8, torch.int8]) or (not extra_input_from_dequant) else match.kwargs["accum_after_dequant"] ) ) if ( len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 or extra_input_of_pattern == compute_node.args[0] ): return False return True return fn def _register_quantized_conv_binary_lowering( pattern, pass_number, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_valid_qconv_lowering_pattern(), pass_number=pass_number, ) def qconv_binary(match: Match, *args, **kwargs): output_dtype = _get_pattern_output_dtype(match) assert output_dtype is not None x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] accum = kwargs["accum"] accum_scale = kwargs["accum_scale"] accum_zp = kwargs["accum_zero_point"] packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) b, stride, padding, dilation, groups = ( kwargs["b"], kwargs["stride"], kwargs["padding"], kwargs["dilation"], kwargs["groups"], ) # Output QParams output_scale = kwargs["output_scale"] output_zero_point = kwargs["output_zero_point"] # post ops binary_op_name = kwargs["binary_op_name"] alpha = kwargs["alpha"] unary_op_name = kwargs["unary_op_name"] unary_op_args = kwargs["unary_op_args"] unary_op_algorithm = kwargs["unary_op_algorithm"] accum.realize() from .mkldnn_fusion import _can_be_inplace assert _can_be_inplace(accum), ( "QConv Binary Inplace Fusion requires accum is not an alias or mutation." 
) computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, accum, b, stride, padding, dilation, groups, output_scale, output_zero_point, output_dtype, accum_scale, accum_zp, binary_op_name, alpha, unary_op_name, unary_op_args, unary_op_algorithm, ) counters["inductor"]["qconv2d_binary_lower_count"] += 1 counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qconv_binary def _register_quantization_unary_lowering(): # QConv2d for users in [1, 2]: qconv_pattern = get_qconv_pt2e_pattern(users) _register_quantized_conv_lowering( qconv_pattern, 2, # pass_number torch.ops.onednn.qconv_pointwise.default, # computation_op ) # QLinear for x_scale_zp_are_tensors in (False, True): qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) computation_op = ( torch.ops.onednn.qlinear_pointwise.tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.default ) _register_quantized_linear_unary_lowering( qlinear_pattern, 2, # pass_number computation_op, ) def _register_quantization_binary_lowering(): # QConv2d for users in (1, 2): qconv_pattern = get_qconv2d_binary_pt2e_pattern(users) _register_quantized_conv_binary_lowering( qconv_pattern, 2, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op ) # QLinear for x_scale_zp_are_tensors in (False, True): qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors) computation_op = ( torch.ops.onednn.qlinear_pointwise.binary_tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.binary ) _register_quantized_linear_binary_lowering( qlinear_pattern, 2, # pass_number computation_op, ) def _is_valid_quantized_maxpool2d_optimization_pattern(): def fn(match): # Only match the pattern which max_pool2d_with_indices returns value # instead of indices. get_item_node = filter_nodes(match.nodes, operator.getitem)[0] return get_item_node.args[1] == 0 return fn def _register_quantized_maxpool2d_lowering( pattern, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), ) def qmaxpool2d(match: Match, *args, **kwargs): x = kwargs["x"] kernel_size = kwargs["kernel_size"] stride = kwargs["stride"] if ("stride" in kwargs) else None padding = kwargs["padding"] if ("padding" in kwargs) else 0 dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False if padding == 0: padding = [0, 0] if dilation == 1: dilation = [1, 1] if not stride: stride = kernel_size kernel_size = pad_listlike(kernel_size, 2) stride = pad_listlike(stride, 2) padding = pad_listlike(padding, 2) dilation = pad_listlike(dilation, 2) assert len(kernel_size) == 2 assert len(stride) == 2 assert len(padding) == 2 assert len(dilation) == 2 computation_args = ( x, kernel_size, stride, padding, dilation, ceil_mode, ) computation_args, _ = require_channels_last(computation_op, *computation_args) counters["inductor"]["qmaxpool2d_matcher_count"] += 1 counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qmaxpool2d def _register_quantization_maxpool2d(): # Currently, the default parameters are not in FX Graph generated by Dynamo export. # So, if user defines nn.MaxPool2d with different assignment of default parameter, # it will generate graph with different number of input nodes and hence # different pattern to be matched. 
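    # For example (illustrative): nn.MaxPool2d(3, stride=2) shows up in the
    # graph as max_pool2d_with_indices with only kernel_size and stride
    # arguments, while nn.MaxPool2d(3, stride=2, padding=1) carries an extra
    # padding argument, so each argument-count combination in the list below
    # gets its own pattern.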
# Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 max_pool2d_args_list = [ [ KeywordArg("stride"), ], [ KeywordArg("stride"), KeywordArg("padding"), ], [ KeywordArg("stride"), KeywordArg("padding"), KeywordArg("dilation"), ], [ KeywordArg("stride"), KeywordArg("padding"), KeywordArg("dilation"), KeywordArg("ceil_mode"), ], ] for max_pool2d_args in max_pool2d_args_list: dequantize_maxpool2d_pattern = CallFunction( aten.max_pool2d_with_indices.default, get_dequantize_per_tensor_activation_pattern(), KeywordArg("kernel_size"), *max_pool2d_args, ) dequantize_lowmem_maxpool2d_pattern = CallFunction( prims._low_memory_max_pool_with_offsets.default, get_dequantize_per_tensor_activation_pattern(), KeywordArg("kernel_size"), *max_pool2d_args, KeywordArg("offset_dtype"), ) dequantize_maxpool2d_get_item_pattern = CallFunction( operator.getitem, dequantize_maxpool2d_pattern, Arg(), ) dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( operator.getitem, dequantize_lowmem_maxpool2d_pattern, Arg(), ) _register_quantized_maxpool2d_lowering( generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), quantized.max_pool2d.default, ) _register_quantized_maxpool2d_lowering( generate_pattern_with_output_quant( dequantize_lowmem_maxpool2d_get_item_pattern ), quantized.max_pool2d.default, ) def _is_input_output_same_scale_zp(check_node): def fn(match): # Ensure all the inputs and output has same scale and zero point # Step 1: Check inputs/output zero point # Get dequant nodes at input dequant_nodes = filter_nodes( match.nodes, quantized_decomposed.dequantize_per_tensor.default ) zero_points = [node.args[2] for node in dequant_nodes] # Get quant nodes at output quant_nodes = filter_nodes( match.nodes, quantized_decomposed.quantize_per_tensor.default ) assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" zero_points.append(quant_nodes[0].args[2]) if not all(zero_point == zero_points[0] for zero_point in zero_points): return False # Step 2: Check inputs/output scale scales = [node.args[1] for node in dequant_nodes] scales.append(quant_nodes[0].args[1]) if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] return False return True return fn def _register_quantized_cat_lowering( pattern, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_input_output_same_scale_zp(aten.cat.default), ) def qcat(match: Match, inputs, dim, **kwargs): # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 
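        # Illustrative example: for a matched torch.cat([t1, t2], dim) over two
        # dequantized int8 tensors, each entry of `inputs` starts with the raw
        # quantized tensor, and only those tensors are forwarded to the
        # quantized cat below; the extra_check above already guarantees that
        # every input and the output share the same scale and zero point.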
uint8_inputs = [input[0] for input in inputs] counters["inductor"]["qcat_matcher_count"] += 1 counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) return L[computation_op](uint8_inputs, dim) return qcat _raw_dequantize_per_tensor_activation_pattern = CallFunction( quantized_decomposed.dequantize_per_tensor.default, Arg(), Arg(), Arg(), Arg(), Arg(), Arg(), ) def _register_quantization_cat(): dequantize_cat_pattern = CallFunction( aten.cat.default, ListOf(_raw_dequantize_per_tensor_activation_pattern), KeywordArg("dim"), ) _register_quantized_cat_lowering( generate_pattern_with_output_quant(dequantize_cat_pattern), aten.cat, ) def _register_quantized_reshape_lowering( pattern, computation_op, ): @register_lowering_pattern( pattern, extra_check=_is_input_output_same_scale_zp(aten.reshape.default), ) def qreshape(match: Match, *args, **kwargs): qx = kwargs["x"] shape = kwargs["shape"] counters["inductor"]["qreshape_matcher_count"] += 1 counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) return L[computation_op](qx, shape) return qreshape def _register_quantization_reshape(): dequantize_reshape_pattern = CallFunction( torch.ops.aten.reshape.default, get_dequantize_per_tensor_activation_pattern(), KeywordArg("shape"), ) _register_quantized_reshape_lowering( generate_pattern_with_output_quant(dequantize_reshape_pattern), aten.reshape, ) def _is_valid_concat_linear_int8_woq_optimization_pattern(): def fn(match): if not config.cpp.enable_concat_linear: return False assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales")) if not all( hasattr(match.kwargs[key], "meta") for key in ["x", "w1", "w2", "w3", "scales"] ): return False x = match.kwargs["x"].meta["val"] w1 = match.kwargs["w1"].meta["val"] w2 = match.kwargs["w2"].meta["val"] w3 = match.kwargs["w3"].meta["val"] scales = match.kwargs["scales"].meta["val"] if len(match.kwargs["scales"].meta["val"].size()) > 1: return False num_scales = match.kwargs["scales"].meta["val"].numel() w1_cols = match.kwargs["w1"].meta["val"].size()[0] w2_cols = match.kwargs["w2"].meta["val"].size()[0] w3_cols = match.kwargs["w3"].meta["val"].size()[0] # Technically, the shapes of the three weights need not be equal. # But currently, we only enable replacement in this case. 
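        # Illustrative example of an accepted match (sizes are assumptions):
        # a fused QKV-style projection where x is bf16 of shape [B, S, K],
        # w1/w2/w3 are int8 of shape [N, K] with the same N, and scales is a
        # 1-D bf16 tensor with numel == 3 * N, all on CPU.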
if w1_cols != w2_cols or w2_cols != w3_cols: return False if 3 * w1_cols != num_scales: return False return ( # For now, we only support woq mm kernels # with x.type=bfloat16 and w.type=int8 x.dtype == torch.bfloat16 and w1.dtype == torch.int8 and w2.dtype == torch.int8 and w3.dtype == torch.int8 and scales.dtype == torch.bfloat16 # _weight_int8pack_mm kernel only supports cpu now # TODO: add cuda kernel support instead of calling mul+sum and x.device.type == "cpu" and x.device == w1.device and w1.device == w2.device and w2.device == w3.device and x.device == scales.device ) return fn def _is_valid_woq_optimization_pattern(): def fn(match): assert all(k in match.kwargs for k in ("x", "weight", "scales")) if not all( hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"] ): return False x = match.kwargs["x"].meta["val"] weight = match.kwargs["weight"].meta["val"] scales = match.kwargs["scales"].meta["val"] return ( # For now, we only support woq mm kernels # with x.type=bfloat16 and w.type=int8 x.dtype == torch.bfloat16 and weight.dtype == torch.int8 and scales.dtype == torch.bfloat16 # _weight_int8pack_mm kernel only supports cpu now # TODO: add cuda kernel support instead of calling mul+sum and x.device.type == "cpu" and x.device == weight.device and x.device == scales.device ) return fn def _register_concat_linear_int8_woq_lowering( pattern, computation_woq, computation_reshape ): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(), pass_number=4, ) def woq(match: Match, *args, **kwargs): x = kwargs["x"] w1 = kwargs["w1"] w2 = kwargs["w2"] w3 = kwargs["w3"] scales = kwargs["scales"] counters["inductor"]["woq_matcher_count"] += 1 counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) out_features = ( w1.meta["val"].size()[0] + w2.meta["val"].size()[0] + w3.meta["val"].size()[0] ) origin_x_size = tuple(x.meta["val"].size()) x_shape = [-1, origin_x_size[-1]] out_shape = list(origin_x_size[:-1] + (out_features,)) mm_node_of_x = None for candidate in iter(x.users.keys()): if ( candidate.target == aten.mm.default and list(candidate._input_nodes)[1].target == aten.cat.default ): mm_node_of_x = candidate break assert mm_node_of_x is not None, "unable to find mm node" _, cat_wgt_node = mm_node_of_x._input_nodes scaling_node = next(iter(mm_node_of_x.users.keys())) user_of_scaling_node = next(iter(scaling_node.users.keys())) # Some other pass is making some changes that entails # adding a node before it's used, but it can only be found when # lint is run. stable_topological_sort() is being run before lint, # so that error was not being being discovered. # We call stable_topological_sort here as a workaround. 
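        # Sketch of the rewrite performed below (shapes and names illustrative):
        #
        #   before: mul(mm(x, cat([to_bf16(w1.t()), to_bf16(w2.t()),
        #                          to_bf16(w3.t())], 1)), scales)
        #   after:  reshape(
        #               _weight_int8pack_mm(reshape(x, [-1, K]),
        #                                   cat([w1, w2, w3], 0), scales),
        #               out_shape)
        #
        # i.e. the three int8 weights stay in int8, are concatenated along
        # dim 0, and the trailing mul-by-scales is absorbed into the fused op.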
stable_topological_sort(match.graph) with match.graph.inserting_before(user_of_scaling_node): new_cat_node = match.graph.call_function( aten.cat.default, args=([w1, w2, w3], 0), ) x_reshape_node = match.graph.call_function( computation_reshape, args=(x, x_shape) ) new_woq_node = match.graph.call_function( computation_woq, args=(x_reshape_node, new_cat_node, scales), ) new_woq_node.meta = copy.copy(x.meta) output_reshape_node = match.graph.call_function( computation_reshape, args=(new_woq_node, out_shape) ) scaling_node.replace_all_uses_with(output_reshape_node) match.graph.erase_node(scaling_node) match.graph.erase_node(mm_node_of_x) match.graph.erase_node(cat_wgt_node) match.graph.lint() return woq def _register_woq_lowering(pattern, computation_woq, computation_reshape): @register_lowering_pattern( pattern, extra_check=_is_valid_woq_optimization_pattern(), ) def woq(match: Match, *args, **kwargs): x = kwargs["x"] weight = kwargs["weight"] scales = kwargs["scales"] counters["inductor"]["woq_matcher_count"] += 1 counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) out_features = weight.get_size()[0] origin_x_size = x.get_size() x_shape = [-1, origin_x_size[-1]] out_shape = origin_x_size[:-1] + [ out_features, ] func1 = L[computation_reshape](x, x_shape) func2 = L[computation_woq](func1, weight, scales) return L[computation_reshape](func2, out_shape) return woq def _register_woq_mm_int8_pattern1(): # F.linear(x, weight.to(dtype=x.dtype)) * scales # case of dispatching to mm, with x reshape _woq_pattern = CallFunction( aten.mul.Tensor, CallFunction( aten.reshape.default, CallFunction( aten.mm.default, CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), CallFunction( aten.permute.default, CallFunction( prims.convert_element_type.default, KeywordArg("weight"), Arg() ), Arg(), ), ), Arg(), ), KeywordArg("scales"), ) _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) def _register_woq_mm_int8_pattern2(): # F.linear(x, weight.to(dtype=x.dtype)) * scales # case of dispatching to mm, w/o x reshape _woq_pattern = CallFunction( aten.mul.Tensor, CallFunction( aten.reshape.default, CallFunction( aten.mm.default, KeywordArg("x"), CallFunction( aten.permute.default, CallFunction( prims.convert_element_type.default, KeywordArg("weight"), Arg() ), Arg(), ), ), Arg(), ), KeywordArg("scales"), ) _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) def _register_woq_mm_int8_pattern3(): # F.linear(x, weight.to(dtype=x.dtype)) * scales # case of dispatching to bmm _woq_pattern = CallFunction( aten.mul.Tensor, CallFunction( aten.bmm.default, CallFunction(aten.expand.default, KeywordArg("x"), Arg()), CallFunction( aten.expand.default, CallFunction( aten.permute.default, CallFunction( prims.convert_element_type.default, KeywordArg("weight"), Arg() ), Arg(), ), Arg(), ), ), KeywordArg("scales"), ) _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) def _register_woq_mm_int8_pattern4(): _woq_pattern = CallFunction( aten.mul.Tensor, CallFunction( aten.mm.default, KeywordArg("x"), CallFunction( prims.convert_element_type.default, CallFunction( aten.permute.default, KeywordArg("weight"), Arg(), ), Arg(), ), ), KeywordArg("scales"), ) _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) def _register_int8_woq_concat_linear_pattern(): def _create_wgt_node(wgt_node_name: str): return CallFunction( prims.convert_element_type.default, CallFunction( aten.permute.default, 
KeywordArg(wgt_node_name), Arg(), ), Arg(), ) cat_wgt = CallFunction( aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1 ) _woq_pattern = CallFunction( aten.mul.Tensor, CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt), KeywordArg("scales"), ) _register_concat_linear_int8_woq_lowering( _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape ) def _register_quantization_lowerings(): _register_quantization_unary_lowering() _register_quantization_binary_lowering() _register_quantization_maxpool2d() _register_quantization_cat() _register_quantization_reshape() def _register_woq_lowerings(): _register_woq_mm_int8_pattern1() _register_woq_mm_int8_pattern2() _register_woq_mm_int8_pattern3() _register_woq_mm_int8_pattern4() def _is_valid_dequant_promotion_pattern(dtype=torch.float32): def _inner(match): assert dtype in [torch.float32, torch.bfloat16] dequant_pattern_end_node = match.output_node() if dequant_pattern_end_node.target not in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, prims.convert_element_type.default, aten.reshape.default, ]: return False if dequant_pattern_end_node.target is aten.reshape.default: dequant_node = ( dequant_pattern_end_node.args[ 0 ] # pattern: linear <- reshape <- dequant if dtype == torch.float32 else dequant_pattern_end_node.args[0].args[ 0 ] # pattern: linear <- reshape <- to_bf16 <- dequant ) else: dequant_node = ( dequant_pattern_end_node # pattern: linear <- dequant if dtype == torch.float32 else dequant_pattern_end_node.args[ 0 ] # pattern: linear <- to_bf16 <- dequant ) if ( dequant_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, ] and len(list(dequant_pattern_end_node.users)) > 1 ): # If dequant pattern has more than 1 users, then do dequant promoted return True return False return _inner def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_promotion_pattern(dtype), pass_number=pass_number, ) def dequant_promotion(match: Match, *args, **kwargs): # Dequant_promotion will transform # graph 1: # quant # + - - - | - - - + # | dequant | # | / \ | # | node1 node2 | # + - | - - - | - + # quant quant # into: # graph 2: # quant # + - - / - \ - - + # |dequant dequant| # | | | | # | node1 node2 | # + - | - - - | - + # quant quant # In graph 1, the dequant node is shared by node1 and node2, # as a result, neither node1 nor node2 could form an int8 # fusion pattern. # After this transformation, the graph 2 could hit the int8 # fusion pattern: dequant-node-quant, respectively for # node1 and node2. 
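        # Concretely (illustrative): if the dequantize_per_tensor node has
        # users [node1, node2], the loop below keeps the original dequant
        # chain for node1 and gives node2 its own copy cloned via
        # clone_to_new_node, while the int8 producer (quant) stays shared.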
        assert dtype in [torch.float32, torch.bfloat16]

        def clone_to_new_node(graph, source_node, user_node):
            # Clone source_node into a new node and replace user_node's input
            # from source_node to the new node
            assert source_node.op == "call_function", (
                "clone_to_new_node only supports nodes with op call_function"
            )
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        # Find the start node and end node of a dequant pattern
        # * End node should be the match.output_node()
        # * Start node should be the dequantize_per_tensor node
        dequant_pattern_end_node = match.output_node()
        assert dequant_pattern_end_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]

        # For a dequant pattern, we expect to see the node list as:
        # * OPT(aten.reshape.default)
        # * OPT(prims.convert_element_type.default) (to_bf16)
        # * dequantize_per_tensor
        def _find_first_node_in_dequant_pattern(_node):
            if _node.target in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]:
                # For a dequant pattern, we expect the start node to be a
                # dequantize_per_tensor node
                return _node
            else:
                assert len(_node.args) >= 1, (
                    "In a dequant pattern, each node should have at least 1 arg."
                )
                return _find_first_node_in_dequant_pattern(_node.args[0])

        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
            dequant_pattern_end_node
        )
        assert dequant_pattern_start_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
        ]

        # Clone the dequant pattern for each user node
        graph = match.graph
        user_node_list = list(dequant_pattern_end_node.users)
        for user_node in user_node_list[1:]:
            _source_node = dequant_pattern_end_node
            _user_node = user_node
            while _source_node != dequant_pattern_start_node.args[0]:
                _user_node = clone_to_new_node(graph, _source_node, _user_node)
                _source_node = _source_node.args[0]  # type: ignore[assignment]

        counters["inductor"]["dequant_promotion_matcher_count"] += 1
        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)


def _is_valid_dequant_conv_pattern(dtype, with_dtype_convert):
    def _inner(match):
        # Here we do some further checks to ensure that:
        # 1. It's a conv1d/2d node (input dim 3 or 4), since we only support
        #    lowering of conv1d/2d now.
        # 2. The dequant pattern has only 1 user (the conv node).
        # If these conditions aren't met, we will not
        # insert a weight prepack node into the matched pattern.
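        # For example (illustrative): an NCHW conv2d whose activation
        # meta["val"] is a 4-D cpu/xpu FakeTensor and whose dequant node has
        # exactly one user passes; a conv3d input (dim() == 5), a CUDA tensor,
        # or a dequant shared by several consumers is rejected and the match
        # is left untouched.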
conv_node = match.output_node() assert conv_node.target is aten.convolution.default input_meta_value = conv_node.args[0].meta.get("val") weight_meta_value = conv_node.args[1].meta.get("val") for meta_value in [input_meta_value, weight_meta_value]: if ( meta_value is None or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") or meta_value.dim() not in [3, 4] ): # Only support conv1d/2d now return False assert dtype in [torch.float32, torch.bfloat16] if not with_dtype_convert: dequant_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] dequant_node = convert_to_bf16.args[0] if len(list(dequant_node.users)) != 1: # Ensure the dequant pattern only has 1 user # since we will delete the dequant pattern here return False return True return _inner def _register_qconv_weight_prepack_pass( pattern, pass_number, dtype=torch.float32, with_dtype_convert=False ): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype, with_dtype_convert), pass_number=pass_number, ) def qconv_weight_prepack(match: Match, *args, **kwargs): """ Match the pattern: int8 activation | dequant_per_tensor | Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight Insert weight prepack node and change the pattern to: int8 activation | onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight """ assert dtype in [torch.float32, torch.bfloat16] conv_node = match.output_node() assert conv_node.target is aten.convolution.default if not with_dtype_convert: dequant_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] has_clone_to_channel_last_node_in_pattern = ( conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] ) clone_node = ( conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None ) if dtype == torch.float32: dequant_per_channel = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) else: weight_to_bf16_node = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] assert ( dequant_per_channel.target # type: ignore[union-attr] is quantized_decomposed.dequantize_per_channel.default ) # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], kwargs["x_zp"], kwargs["x_scale"], ) # Weight QParams qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # Conv Params bias, stride, padding, dilation, groups = ( kwargs["b"], kwargs["stride"], kwargs["padding"], kwargs["dilation"], kwargs["groups"], ) x_shape = qx.meta.get("tensor_meta").shape if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None graph = match.graph with graph.inserting_before(conv_node): # Insert weight prepack node and the QConv node packed_weight_inputs = ( qw, w_scale, x_scale, x_zp, stride, padding, dilation, groups, x_shape, ) packed_weight_op = torch.ops.onednn.qconv_prepack prepack_weight_node = graph.call_function( packed_weight_op, args=packed_weight_inputs ) new_args: tuple[Any, ...] 
= ( qx, x_scale, x_zp, prepack_weight_node, w_scale, w_zp, bias, stride, padding, dilation, groups, 1.0, # output_scale 0, # output_zero_point dtype, # output_dtype "none", # attr [], # scalars "", # algorithm ) new_conv_node = graph.call_function( torch.ops.onednn.qconv_pointwise.default, args=new_args ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) # Erase the original conv node graph.erase_node(conv_node) # Erase the dequant pattern if with_dtype_convert: graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] graph.erase_node(dequant_node) # type: ignore[arg-type] # Erase the dequant per channel pattern if clone_node is not None: graph.erase_node(clone_node) # type: ignore[arg-type] if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] graph.erase_node(dequant_per_channel) # type: ignore[arg-type] counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes ) def _generate_dequant_convolution_node_pattern( _dequant_per_channel_pattern, dtype=torch.float32, with_dtype_convert=False ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(), KeywordArg("autocast_act_dtype"), with_dtype_convert, ), _dequant_per_channel_pattern, KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), KeywordArg("dilation"), KeywordArg("is_transposed"), KeywordArg("out_padding"), KeywordArg("groups"), ) return dequant_convolution_node_pattern def _generate_qconv_weight_prepack_patterns( dtype=torch.float32, with_dtype_convert=False ): assert dtype in [torch.float32, torch.bfloat16] return ( _generate_dequant_convolution_node_pattern( dequantize_per_channel_weight_pattern if dtype == torch.float32 else dequantize_per_channel_to_bf16_weight_pattern, dtype, with_dtype_convert, ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 
# Depend on some heuristics, it may or may not insert to(channel_last) node # between convolution and dequant_per_channel node _generate_dequant_convolution_node_pattern( dequantize_per_channel_clone_weight_pattern if dtype == torch.float32 else dequantize_per_channel_to_bf16_clone_weight_pattern, dtype, with_dtype_convert, ), ) def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): output_reshape_node = None if input_dim_exceeds_two: if input_contiguous: output_reshape_node = match.output_node() assert output_reshape_node.target is aten.reshape.default linear_node = output_reshape_node.args[0] else: linear_nodes = filter_nodes(match.nodes, aten.bmm.default) assert len(linear_nodes) == 1 linear_node = linear_nodes[0] else: linear_node = match.output_node() assert linear_node.target in ( aten.addmm.default, aten.mm.default, aten.bmm.default, ) return linear_node, output_reshape_node def _get_linear_dq_node( linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, ): act_reshape_node = None activation_to_bf16_node = None act_expand_node = None if input_dim_exceeds_two: if input_contiguous: act_reshape_node = linear_node.args[input_index] assert act_reshape_node.target is aten.reshape.default if not with_dtype_convert: # pattern: linear -> reshape -> dequant dequant_node = act_reshape_node.args[0] else: # pattern: linear -> reshape -> to_bf16 -> dequant activation_to_bf16_node = act_reshape_node.args[0] dequant_node = activation_to_bf16_node.args[0] else: # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous act_expand_node = linear_node.args[input_index] assert act_expand_node.target is aten.expand.default if not with_dtype_convert: dequant_node = act_expand_node.args[0] else: activation_to_bf16_node = act_expand_node.args[0] dequant_node = activation_to_bf16_node.args[0] else: if not with_dtype_convert: # pattern: linear -> dequant dequant_node = linear_node.args[input_index] else: # pattern: linear -> to_bf16 -> dequant activation_to_bf16_node = linear_node.args[input_index] dequant_node = activation_to_bf16_node.args[0] return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node def _is_valid_dequant_linear_pattern( dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert ): def _inner(match): # Check dequant pattern has only 1 user. 
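        # The activation side of the match takes one of three shapes
        # (each optionally with a to_bf16 node in front of the dequant):
        #   * 2D input:                       addmm/mm <- dequant
        #   * >2D input, contiguous:          addmm/mm <- reshape <- dequant
        #   * >2D input, not contiguous:      bmm <- expand <- dequant
        # _get_linear_dq_node (called below) walks this chain to locate the
        # dequant node.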
( linear_node, _, ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) input_index = 1 if linear_node.target is aten.addmm.default else 0 assert dtype in [torch.float32, torch.bfloat16] ( dequant_node, _, _, _, ) = _get_linear_dq_node( linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, ) assert dequant_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, ] if len(list(dequant_node.users)) != 1: # Ensure the dequant pattern only has 1 user # since we will delete the dequant pattern here return False # Extra check for bmm pattern if input_dim_exceeds_two and not input_contiguous: # Check for act # Act expand size should be exactly same as act size act_expand_size = match.kwargs["act_expand_size"] act_node = match.kwargs["x"] if not ( hasattr(act_node, "meta") and isinstance(act_node.meta.get("val", None), torch.Tensor) and (act_node.meta["val"].size() == torch.Size(act_expand_size)) ): return False # Check for wgt # wgt permute dims should be [1, 0] wgt_permute_dims = match.kwargs["permute_axes"] if wgt_permute_dims != [1, 0]: return False # Check below wgt size items: # wgt before expand should with dim 2 # Expand size should with dim 3 # Expand size[0] should same as act size[0] # Expand size[1] should same as wgt size[1] # Expand size[2] should same as wgt size[0] qweight_node = match.kwargs["q_weight"] wgt_expand_size = match.kwargs["wgt_expand_size"] if not ( hasattr(qweight_node, "meta") and isinstance(qweight_node.meta.get("val", None), torch.Tensor) and len(qweight_node.meta["val"].size()) == 2 and len(wgt_expand_size) == 3 and wgt_expand_size[0] == act_node.meta["val"].size()[0] and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] ): return False return True return _inner def _register_qlinear_weight_prepack_pass( pattern, pass_number, dtype=torch.float32, input_dim_exceeds_two=False, input_contiguous=True, with_dtype_convert=False, ): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_linear_pattern( dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert ), pass_number=pass_number, ) def qlinear_weight_prepack(match: Match, *args, **kwargs): """ Match the pattern: int8 activation | dequant_per_tensor | mm/addmm <- t <- dequant_per_channel <- int8_weight Insert weight prepack node and change the pattern to: int8 activation | onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight """ assert dtype in [torch.float32, torch.bfloat16] ( linear_node, output_reshape_node, ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) input_index = 1 if linear_node.target is aten.addmm.default else 0 weight_index = input_index + 1 ( dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node, ) = _get_linear_dq_node( linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, ) if input_dim_exceeds_two and not input_contiguous: wgt_expand_node = linear_node.args[weight_index] assert wgt_expand_node.target is aten.expand.default t_node = wgt_expand_node.args[0] else: t_node = linear_node.args[weight_index] if dtype == torch.float32: dequant_per_channel = t_node.args[0] else: weight_to_bf16_node = t_node.args[0] dequant_per_channel = weight_to_bf16_node.args[0] assert ( dequant_per_channel.target is quantized_decomposed.dequantize_per_channel.default ) # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], 
kwargs["x_zp"], kwargs["x_scale"], ) # Weight QParams qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # Params bias = kwargs["b"] if "b" in kwargs else None x_shape = qx.meta.get("tensor_meta").shape if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None graph = match.graph with graph.inserting_before(linear_node): # Insert weight prepack node and the qlinear node packed_weight_inputs = ( qw, x_shape, ) packed_weight_op = torch.ops.onednn.qlinear_prepack prepack_weight_node = graph.call_function( packed_weight_op, args=packed_weight_inputs ) new_args: tuple[Any, ...] = ( qx, x_scale, x_zp, prepack_weight_node, w_scale, w_zp, bias, 1.0, # output_scale 0, # output_zero_point dtype, # output_dtype "none", # post op name [], # post op args "", # post op algorithm ) Node = torch.fx.node.Node if isinstance(x_scale, Node) and isinstance(x_zp, Node): new_linear_node = graph.call_function( torch.ops.onednn.qlinear_pointwise.tensor, args=new_args ) else: new_linear_node = graph.call_function( torch.ops.onednn.qlinear_pointwise.default, args=new_args ) if input_dim_exceeds_two: if input_contiguous: output_reshape_node.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(output_reshape_node.meta) else: if bias: output_add_node_for_bias = match.output_node() assert output_add_node_for_bias.target is aten.add.Tensor output_add_node_for_bias.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(output_add_node_for_bias.meta) else: linear_node.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(linear_node.meta) else: linear_node.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(linear_node.meta) # Erase the original linear node if input_dim_exceeds_two: if input_contiguous: graph.erase_node(output_reshape_node) elif not input_contiguous and bias: graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] graph.erase_node(linear_node) if input_dim_exceeds_two: if input_contiguous: graph.erase_node(act_reshape_node) else: graph.erase_node(act_expand_node) graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] if with_dtype_convert: graph.erase_node(activation_to_bf16_node) # Erase the dequant pattern graph.erase_node(dequant_node) # Erase the dequant per channel pattern graph.erase_node(t_node) if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] graph.erase_node(dequant_per_channel) counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( match.nodes ) def _generate_dequant_linear_node_pattern( _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False, is_tensor_overload=False, with_dtype_convert=False, ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( CallFunction( aten.addmm.default, KeywordArg("b"), _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), with_dtype_convert, ), KeywordArg("act_reshape_size"), input_dim_exceeds_two, ), t_pattern, ), KeywordArg("output_reshape_size"), input_dim_exceeds_two, ) dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( CallFunction( aten.mm.default, 
_may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), with_dtype_convert, ), KeywordArg("act_reshape_size"), input_dim_exceeds_two, ), t_pattern, ), KeywordArg("output_reshape_size"), input_dim_exceeds_two, ) return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern def _generate_dequant_bmm_node_pattern( _dequant_per_channel_pattern, dtype=torch.float32, with_bias=False, is_tensor_overload=False, with_dtype_convert=False, ): # When activation of linear dim exceed 2 and not contiguous t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) assert dtype in [torch.float32, torch.bfloat16] dequant_bmm_pattern = CallFunction( aten.bmm.default, CallFunction( aten.expand.default, _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), with_dtype_convert, ), KeywordArg("act_expand_size"), ), CallFunction( aten.expand.default, t_pattern, KeywordArg("wgt_expand_size"), ), ) def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): if _with_bias: return CallFunction( aten.add.Tensor, _dequant_bmm_pattern, KeywordArg("b"), ) else: return _dequant_bmm_pattern return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) def _generate_qlinear_weight_prepack_patterns( dtype=torch.float32, input_dim_exceeds_two=False, input_contiguous=True, with_bias=False, is_tensor_overload=False, with_dtype_convert=False, ): if input_dim_exceeds_two and not input_contiguous: return _generate_dequant_bmm_node_pattern( dequantize_per_channel_weight_pattern, dtype, with_bias, is_tensor_overload, with_dtype_convert, ) else: return _generate_dequant_linear_node_pattern( dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two, is_tensor_overload, with_dtype_convert, ) def _generate_linear_dynamic_fp16_pattern( _dequant_weight_pattern, input_dim_exceeds_two=False, input_contiguous=True, relu_fused=False, ): dtype = torch.float32 t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype) if input_dim_exceeds_two and not input_contiguous: # pattern is # x -> expand -> bmm (-> add) (-> relu) # w -> dequant -> permute -> expand / pattern_no_bias = CallFunction( aten.bmm.default, CallFunction( aten.expand.default, KeywordArg("x"), KeywordArg("act_expand_size"), ), CallFunction( aten.expand.default, t_pattern, KeywordArg("wgt_expand_size"), ), ) pattern_with_bias = CallFunction( aten.add.Tensor, pattern_no_bias, KeywordArg("b"), ) if relu_fused: pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias) pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias) return pattern_with_bias, pattern_no_bias x_pattern_with_reshape = _may_generate_pattern_with_reshape( KeywordArg("x"), KeywordArg("act_reshape_size"), input_dim_exceeds_two, ) dequant_linear_bias_pattern = generate_pattern_with_unary( _may_generate_pattern_with_reshape( CallFunction( aten.addmm.default, KeywordArg("b"), x_pattern_with_reshape, t_pattern, ), KeywordArg("output_reshape_size"), input_dim_exceeds_two, ), aten.relu.default if relu_fused else None, ) dequant_linear_no_bias_pattern = generate_pattern_with_unary( _may_generate_pattern_with_reshape( CallFunction( aten.mm.default, x_pattern_with_reshape, t_pattern, ), KeywordArg("output_reshape_size"), input_dim_exceeds_two, ), aten.relu.default if relu_fused else None, ) return dequant_linear_bias_pattern, 
dequant_linear_no_bias_pattern def _register_dequant_promotion(): dequant_pattern_cases = itertools.product( [torch.float32, torch.bfloat16], [True, False], [True, False] ) for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: # 4 dequantization patterns will be matched based on the dtype and input dimension size. # Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 # Case 3: int8-mixed-bf16, input dim size is 2 # Case 4: int8-mixed-bf16, input dim size exceeds 2 # quant # + - - - - | - - - - + # | dequant | # | | | # | OPT(to_bf16) | # | | | # | OPT(reshape) | # | / \ | # | node1 node2 | # + - - | - - - | - - + # OPT(reshape) OPT(reshape) # + - - | - - - | - - + # OPT(to_fp32) OPT(to_fp32) # + - - | - - - | - - + # quant quant _register_dequant_promotion_pass( _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern( is_tensor_overload=is_tensor_overload ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), KeywordArg("act_reshape_size"), with_reshape=input_dim_exceeds_two, ), pass_number=0, dtype=dtype, ) # pass_number=0 to run before weight prepack def _register_qconv_weight_prepack(): for dtype, with_dtype_convert in itertools.product( [torch.float32, torch.bfloat16], [True, False] ): if dtype == torch.float32 and with_dtype_convert: continue weight_prepack_patterns = _generate_qconv_weight_prepack_patterns( dtype, with_dtype_convert ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. _register_qconv_weight_prepack_pass( weight_prepack_pattern, pass_number=1, dtype=dtype, with_dtype_convert=with_dtype_convert, ) def _register_qlinear_weight_prepack(): # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. # Then convert the pattern into a QLinear node with int8_fp32/bf16. # Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous # Case 3: int8-mixed-bf16, input dim size is 2 # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous # + - - - - | - - - - - - | - - - - - + # | dq_per_tensor dq_per_channel | # | | | | # | OPT(to_bf16) OPT(to_bf16) | # | | | | # | OPT(reshape) permute | # | \ / | # | addmm/mm | # | | | # | OPT(reshape) | # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous # + - - - - | - - - - - - | - - - - - + # | dq_per_tensor dq_per_channel | # | | | | # | OPT(to_bf16) OPT(to_bf16) | # | | | | # | expand permute | # | \ | | # | expand | # | / | # | bmm | # | | | # | OPT(add) | linear_weight_prepack_cases = itertools.product( [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) # Step 1: register patterns from mm and addmm for ( dtype, input_dim_exceeds_two, is_tensor_overload, with_dtype_convert, ) in linear_weight_prepack_cases: if dtype == torch.float32 and with_dtype_convert: continue weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload=is_tensor_overload, with_dtype_convert=with_dtype_convert, ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 
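            # Patterns registered with a lower pass_number run first, so the dequant
            # promotion patterns (pass_number=0, see _register_dequant_promotion above)
            # get to duplicate a shared dequant node before these prepack patterns match.
            # Roughly, the registered pattern for the fp32 2D case is
            #   addmm(b, dequantize_per_tensor(x), permute(dequantize_per_channel(w)))
            # (or mm(...) for the no-bias variant), which the handler then replaces with
            # onednn.qlinear_prepack + onednn.qlinear_pointwise.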
_register_qlinear_weight_prepack_pass( weight_prepack_pattern, pass_number=1, dtype=dtype, input_dim_exceeds_two=input_dim_exceeds_two, with_dtype_convert=with_dtype_convert, ) # Step 2: register patterns from bmm # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous # refer to: # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 # in this case, we can convert it back to qlinear for ( dtype, with_bias, is_tensor_overload, with_dtype_convert, ) in itertools.product( [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ): if dtype == torch.float32 and with_dtype_convert: continue bmm_pattern = _generate_qlinear_weight_prepack_patterns( dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, with_bias=with_bias, is_tensor_overload=is_tensor_overload, with_dtype_convert=with_dtype_convert, ) _register_qlinear_weight_prepack_pass( bmm_pattern, pass_number=1 if with_bias else 2, # if with_bias, there is an output add, so we should try to match it firstly dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, with_dtype_convert=with_dtype_convert, ) def _register_linear_dynamic_fp16_weight_prepack_pass( pattern, pass_number, input_dim_exceeds_two=False, input_contiguous=True, relu_fused=False, ): def _extra_check_fn(match: Match): return match.kwargs["dtype_fp16"] == torch.float16 @register_freezing_graph_pattern( pattern, extra_check=_extra_check_fn, pass_number=pass_number, ) def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs): """ Match the pattern: fp32 activation | mm/addmm <- t <- to_fp32 <- to_fp16 <- weight | (reshape) <- (relu) OR fp32 activation | expand | bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight | (add) <- (relu) Insert weight prepack node and change the pattern to: fp32 activation | onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight (or onednn.linear_relu_dynamic_fp16) """ # find params x = kwargs["x"] w = kwargs["w"] bias = kwargs["b"] if "b" in kwargs else None # find linear node nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default] linear_nodes = [] for node in nodes_to_find: linear_nodes.extend(filter_nodes(match.nodes, node)) assert len(linear_nodes) == 1 linear_node = linear_nodes[0] assert isinstance(linear_node, torch.fx.node.Node) input_index = 1 if linear_node.target is aten.addmm.default else 0 weight_index = input_index + 1 # find relu node relu_node = None if relu_fused: relu_node = match.output_node() assert isinstance(relu_node, torch.fx.node.Node) # find reshape node, expand node and add node ( act_reshape_node, output_reshape_node, expand_x_node, expand_w_node, add_bias_node, ) = (None, None, None, None, None) t_node = None if input_dim_exceeds_two: if input_contiguous: act_reshape_node = linear_node.args[input_index] t_node = linear_node.args[weight_index] output_reshape_node = next(iter(linear_node.users)) assert output_reshape_node.target is aten.reshape.default else: expand_x_node = linear_node.args[input_index] expand_w_node = linear_node.args[weight_index] assert isinstance(expand_w_node, torch.fx.node.Node) t_node = expand_w_node.args[0] if bias: add_bias_node = next(iter(linear_node.users)) assert add_bias_node.target is aten.add.Tensor else: t_node = linear_node.args[weight_index] assert isinstance(t_node, torch.fx.node.Node) w_to_fp32_node = t_node.args[0] assert ( isinstance(w_to_fp32_node, torch.fx.node.Node) and w_to_fp32_node.target is 
quantized_decomposed.convert_element_type.no_fuse ) w_to_fp16_node = w_to_fp32_node.args[0] assert ( isinstance(w_to_fp16_node, torch.fx.node.Node) and w_to_fp16_node.target is quantized_decomposed.convert_element_type.no_fuse ) x_shape = x.meta.get("tensor_meta").shape if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None graph = match.graph with graph.inserting_before(linear_node): # Insert weight prepack node and the qlinear node packed_weight_inputs = ( w, x_shape, ) packed_weight_op = torch.ops.onednn.linear_prepack_fp16 prepack_weight_node = graph.call_function( packed_weight_op, args=packed_weight_inputs ) # create new linear node and insert on graph new_args: tuple[Any, ...] = ( x, prepack_weight_node, bias, ) linear_op = ( torch.ops.onednn.linear_relu_dynamic_fp16.default if relu_fused else torch.ops.onednn.linear_dynamic_fp16.default ) new_linear_node = graph.call_function(linear_op, args=new_args) out_node = match.output_node() out_node.replace_all_uses_with(new_linear_node) # Erase the original nodes in the reverse order new_linear_node.meta.update(out_node.meta) if relu_node is not None: graph.erase_node(relu_node) if output_reshape_node is not None: graph.erase_node(output_reshape_node) if add_bias_node is not None: graph.erase_node(add_bias_node) graph.erase_node(linear_node) if act_reshape_node is not None: assert isinstance(act_reshape_node, torch.fx.node.Node) graph.erase_node(act_reshape_node) if expand_x_node is not None: assert isinstance(expand_x_node, torch.fx.node.Node) graph.erase_node(expand_x_node) if expand_w_node is not None: assert isinstance(expand_w_node, torch.fx.node.Node) graph.erase_node(expand_w_node) graph.erase_node(t_node) graph.erase_node(w_to_fp32_node) graph.erase_node(w_to_fp16_node) counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( match.nodes ) def _register_linear_dynamic_fp16_weight_prepack(): to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse weight_pattern = CallFunction( to_dtype_op, CallFunction( to_dtype_op, KeywordArg("w"), KeywordArg("dtype_fp16"), ), KeywordArg("dtype_fp32"), ) cases = itertools.product( [False, True], # input_dim_exceeds_two [True, False], # input_contiguous [False, True], # relu fused ) for input_dim_exceeds_two, input_contiguous, relu_fused in cases: patterns = _generate_linear_dynamic_fp16_pattern( weight_pattern, input_dim_exceeds_two, input_contiguous, relu_fused, ) for pattern in patterns: _register_linear_dynamic_fp16_weight_prepack_pass( pattern, pass_number=0 if relu_fused else 1, input_dim_exceeds_two=input_dim_exceeds_two, input_contiguous=input_contiguous, relu_fused=relu_fused, ) def _register_smooth_quant_int_mm_pattern(): """ The pattern is: (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape or (with bias) pattern_no_bias -> add (-> reshape -> reshape) """ # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist # When torch.compile'ing with dynamic=False, they don't exist def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True): return CallFunction( aten.mul.Tensor, CallFunction( aten.mul.Tensor, CallFunction( prims.convert_element_type.default, CallFunction( aten._int_mm.default, CallFunction( aten.reshape.default, KeywordArg("a"), KeywordArg("in_shape"), ) if reshape_a else KeywordArg("a"), KeywordArg("b"), ), KeywordArg("dtype"), ), ( 
CallFunction( aten.expand.default, KeywordArg("x_scale"), Arg(), ) if expand_a_scale else KeywordArg("x_scale") ), ), KeywordArg("w_scale"), ) def _with_outer_reshape(pattern): return CallFunction( aten.reshape.default, pattern, KeywordArg("out_shape_no_bias") ) # for torch.compile(dynamic=False) pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False)) pattern_with_bias_1 = CallFunction( aten.add.Tensor, pattern_no_bias_1, KeywordArg("bias"), ) # for torch.compile(dynamic=True) pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True)) pattern_with_bias_2 = CallFunction( aten.reshape.default, CallFunction( aten.reshape.default, CallFunction( aten.add.Tensor, pattern_no_bias_2, KeywordArg("bias"), ), Arg(), ), KeywordArg("out_shape_with_bias"), ) # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, # when both activation and weights are symmetrically quantized. # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op. pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias( expand_a_scale=False, reshape_a=False ) pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias( expand_a_scale=True, reshape_a=False ) def _validate_pattern(match: Match): if len(match.nodes) not in [4, 5, 6, 7, 10]: return False # Make sure weight is a constant aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0] if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node): return False if aten_int_mm_node.args[1].op != "get_attr": return False if len(match.nodes) == 10: # Check the two tailing reshape nodes can be fused if match.nodes[9].args[1] != match.nodes[6].args[1]: return False if len(match.nodes) == 10 or ( len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor ): bias_idx = 7 if len(match.nodes) == 10 else 6 # Check bias shape bias_node = match.nodes[bias_idx].args[1] if not isinstance(bias_node, torch.fx.node.Node): return False if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr] return False return True pattern_to_pass_number = { pattern_no_bias_2: 0, pattern_with_bias_2: 0, pattern_no_bias_1: 1, pattern_with_bias_1: 1, pattern1_with_no_outer_or_act_reshape: 2, pattern2_with_no_outer_or_act_reshape: 2, } for pattern, pass_number in pattern_to_pass_number.items(): @register_freezing_graph_pattern( pattern, extra_check=_validate_pattern, pass_number=pass_number, ) def _int_mm_weight_prepack(match: Match, *args, **kwargs): bias = kwargs.get("bias", None) x = kwargs["a"] weight = kwargs["b"] dtype = kwargs["dtype"] x_scale = kwargs["x_scale"] w_scale = kwargs["w_scale"] x_shape = x.meta.get("tensor_meta").shape if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. 
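                # x_shape is only forwarded to the onednn.qlinear_prepack call below
                # (as part of packed_weight_inputs), so None simply means that no
                # concrete activation shape is available for weight packing.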
x_shape = None out_node = match.output_node() with match.graph.inserting_before(out_node): transpose_node = match.graph.call_function( aten.permute.default, args=(weight, [1, 0]) ) contig_node = match.graph.call_function( aten.contiguous.default, args=(transpose_node,) ) packed_weight_inputs = ( contig_node, x_shape, ) packed_weight_op = torch.ops.onednn.qlinear_prepack prepack_weight_node = match.graph.call_function( packed_weight_op, args=packed_weight_inputs ) dummy_zp = None w_scale = match.graph.call_function( prims.convert_element_type.default, args=(w_scale, torch.float32) ) x_scale_shape = x_scale.meta.get("tensor_meta").shape x_scale_is_scalar = False if not has_free_symbols(x_scale_shape): prod = 1 for d in x_scale_shape: prod *= d x_scale_is_scalar = prod == 1 new_args: tuple[Any, ...] if x_scale_is_scalar: # in this case, we can call onednn.qlinear directly new_args = ( x, x_scale, dummy_zp, # x_zp prepack_weight_node, w_scale, dummy_zp, # w_zp bias, 1.0, # output_scale 0, # output_zero_point dtype, # output_dtype "none", # post op name [], # post op args "", # post op algorithm ) new_linear_node = match.graph.call_function( torch.ops.onednn.qlinear_pointwise.tensor, args=new_args ) out_node.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(out_node.meta) else: # onednn.qlinear does not support per-channel quantization of x # so in this case, we have to apply x scale and add bias ourselves after qlinear in_shape = kwargs.get("in_shape", None) if in_shape is None: x_reshaped = x else: x_reshaped = match.graph.call_function( aten.reshape.default, args=(x, in_shape) ) new_args = ( x_reshaped, 1.0, # x_scale 0, # x_zp prepack_weight_node, w_scale, dummy_zp, # w_zp None, # bias 1.0, # output_scale 0, # output_zero_point dtype, # output_dtype "none", # post op name [], # post op args "", # post op algorithm ) new_linear_node = match.graph.call_function( torch.ops.onednn.qlinear_pointwise, args=new_args ) # apply x scale new_out_node = match.graph.call_function( aten.mul.Tensor, args=(new_linear_node, x_scale) ) # Add bias and reshape has_outer_reshape = ( kwargs.get("out_shape_with_bias", None) is not None or kwargs.get("out_shape_no_bias", None) is not None ) if has_outer_reshape: out_shape = kwargs.get( "out_shape_with_bias", kwargs["out_shape_no_bias"] ) if bias is not None: new_out_node = match.graph.call_function( aten.add.Tensor, args=(new_out_node, bias) ) if has_outer_reshape: new_out_node = match.graph.call_function( aten.reshape.default, args=(new_out_node, out_shape), # type: ignore[possibly-undefined] ) else: if has_outer_reshape: new_out_node = match.graph.call_function( aten.reshape.default, args=(new_out_node, out_shape), # type: ignore[possibly-undefined] ) out_node.replace_all_uses_with(new_out_node) new_out_node.meta.update(out_node.meta) for node in reversed(match.nodes): match.graph.erase_node(node) counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( match.nodes ) class PostOpAttr: def __init__( self, binary_op_name: str = "none", alpha=None, unary_op_name: str = "none", scalars_attr=None, algorithm_attr=None, ) -> None: self.binary_op_name = binary_op_name self.alpha = alpha if alpha else 1.0 self.unary_op_name = unary_op_name self.scalars_attr = scalars_attr if scalars_attr else [] self.algorithm_attr = algorithm_attr if algorithm_attr else "" def _register_qconv_post_op_fusion_pass( pattern, pass_number, computation_op, post_op_attr, ): has_binary_post_op = 
post_op_attr.binary_op_name != "none" @register_freezing_graph_pattern( pattern, extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), pass_number=pass_number, ) def qconv(match: Match, *args, **kwargs): # Activation QParams x, x_scale, x_zp = ( kwargs["x"], kwargs["x_scale"], kwargs["x_zp"], ) # Weight QParams packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # Conv Params b, stride, padding, dilation, groups = ( kwargs["b"], kwargs["stride"], kwargs["padding"], kwargs["dilation"], kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] # Output QParams o_inv_scale = ( kwargs["o_inv_scale"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) else 1.0 ) o_zero_point = ( kwargs["o_zp"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) else 0 ) assert ( kwargs["postop_name"] == "none" ) # Expected no post op fused in weight prepack phase if post_op_attr.unary_op_name == "hardtanh": min_value = kwargs.get("min_value") max_value = kwargs.get("max_value") post_op_attr.scalars_attr = [min_value, max_value] out_node = match.output_node() with match.graph.inserting_before(out_node): if not has_binary_post_op: computation_args: tuple[Any, ...] = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, b, stride, padding, dilation, groups, o_inv_scale, o_zero_point, output_dtype, post_op_attr.unary_op_name, post_op_attr.scalars_attr, post_op_attr.algorithm_attr, ) else: accum = ( kwargs["accum"] if output_dtype in [torch.uint8, torch.int8] else kwargs["accum_after_dequant"] ) accum_scale = ( kwargs["accum_scale"] if output_dtype in [torch.uint8, torch.int8] else 1.0 ) accum_zp = ( kwargs["accum_zp"] if output_dtype in [torch.uint8, torch.int8] else 0 ) computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, accum, b, stride, padding, dilation, groups, o_inv_scale, o_zero_point, output_dtype, accum_scale, accum_zp, post_op_attr.binary_op_name, post_op_attr.alpha, post_op_attr.unary_op_name, post_op_attr.scalars_attr, post_op_attr.algorithm_attr, ) new_conv_node = match.graph.call_function( computation_op, args=computation_args ) out_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(out_node.meta) for node in reversed(match.nodes): match.graph.erase_node(node) count_key = ( "qconv2d_binary_matcher_count" if has_binary_post_op else "qconv_unary_matcher_count" ) nodes_key = ( "qconv2d_binary_matcher_nodes" if has_binary_post_op else "qconv_unary_matcher_nodes" ) counters["inductor"][count_key] += 1 counters["inductor"][nodes_key] += len(match.nodes) return qconv def _register_qconv_unary_fusion(): from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: # Priority 1 to match: QConv2d Unary pattern with int8 output # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. 
# For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant is_bf16 = original_pattern_output_dtype == torch.bfloat16 conv_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( get_qconv_pt2e_pattern(1), ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( get_qconv_pt2e_pattern(1), aten.relu.default ), ), PostOpAttr( "none", None, "hardtanh", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, get_qconv_pt2e_pattern(1), 1, is_bf16, ), with_dtype_convert=is_bf16, ), PostOpAttr( "none", None, "hardswish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, ), PostOpAttr( "none", None, "swish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, ), } for unary_attr, patterns in conv_unary_replace_patterns.items(): # Register qconv2d pattern for ExternKernel Lowering _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( get_qconv_pt2e_pattern(1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, get_qconv_pt2e_pattern(1), 1, is_bf16, ), Arg(), is_bf16, ), PostOpAttr( "none", None, "hardswish", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), Arg(), is_bf16, ), PostOpAttr( "none", None, "swish", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), Arg(), is_bf16, ), } for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): # Register qconv2d pattern for ExternKernel Lowering _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) def _register_qconv_binary_fusion(): for int8_mixed_bf16_with_inplace_add in [False, True]: # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output swap_binary_inputs_list = [False, True] binary_replace_patterns = {} for swap_inputs in swap_binary_inputs_list: binary_replace_patterns.update( { PostOpAttr( "sum", 1.0, "none", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), ), PostOpAttr( "sum", 1.0, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ), ), } ) for binary_unary_attr, patterns in binary_replace_patterns.items(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) # 
Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output binary_replace_float_out_patterns = {} for swap_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ) } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): if int8_mixed_bf16_with_inplace_add: _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) else: _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output binary_replace_float_out_patterns = {} for swap_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr( "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): _register_qconv_post_op_fusion_pass( patterns, 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op binary_unary_attr, # binary_unary_attr ) def _register_qlinear_post_op_fusion_pass( pattern, pass_number, computation_op, post_op_attr, ): has_binary_post_op = post_op_attr.binary_op_name != "none" @register_freezing_graph_pattern( pattern, extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op), pass_number=pass_number, ) def qlinear_post_op_fusion(match: Match, *args, **kwargs): """ Match the pattern: qlinear - post op """ output_dtype = _get_pattern_output_dtype(match) # Activation QParams x, x_scale, x_zp = ( kwargs["x"], kwargs["x_scale"], kwargs["x_zp"], ) # Weight QParams packed_weight, w_scale, w_zp = ( kwargs["packed_weight"], kwargs["w_scale"], kwargs["w_zp"], ) # bias b = kwargs["b"] if "b" in kwargs else None # Output QParams o_inv_scale = ( kwargs["o_inv_scale"] if (output_dtype in [torch.uint8, torch.int8]) else 1.0 ) o_zero_point = ( kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 ) assert ( kwargs["postop_name"] == "none" ) # Expected no post op fused in weight prepack phase out_node = match.output_node() with match.graph.inserting_before(out_node): if not has_binary_post_op: computation_args: tuple[Any, ...] 
= ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, b, o_inv_scale, o_zero_point, output_dtype, post_op_attr.unary_op_name, post_op_attr.scalars_attr, post_op_attr.algorithm_attr, ) else: other = kwargs["other"] if "other" in kwargs else kwargs["accum"] x2_scale = 1.0 x2_zp = 0 computation_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, other, b, o_inv_scale, o_zero_point, output_dtype, x2_scale, x2_zp, post_op_attr.binary_op_name, post_op_attr.alpha, post_op_attr.unary_op_name, post_op_attr.scalars_attr, post_op_attr.algorithm_attr, ) new_linear_node = match.graph.call_function( computation_op, args=computation_args ) out_node.replace_all_uses_with(new_linear_node) new_linear_node.meta.update(out_node.meta) for node in reversed(match.nodes): match.graph.erase_node(node) count_key = ( "qlinear_binary_matcher_count" if has_binary_post_op else "qlinear_unary_matcher_count" ) nodes_key = ( "qlinear_binary_matcher_nodes" if has_binary_post_op else "qlinear_unary_matcher_nodes" ) counters["inductor"][count_key] += 1 counters["inductor"][nodes_key] += len(match.nodes) def _register_qlinear_unary_fusion(): from .mkldnn_fusion import ( _gelu_fusion_1 as _gelu_fusion_erf, _gelu_fusion_2 as _gelu_fusion_tanh, ) for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: is_bf16 = original_pattern_output_dtype == torch.bfloat16 for x_scale_zp_are_tensors in (False, True): qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) computation_op = ( torch.ops.onednn.qlinear_pointwise.tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.default ) # Priority 1 to match: QLinear Unary pattern with int8 output linear_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( qlinear_pattern, ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary(qlinear_pattern, aten.relu.default), ), PostOpAttr( "none", None, "gelu", [], "none" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _gelu_fusion_erf, get_qlinear_pt2e_pattern( x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), 2, is_bf16, ), with_dtype_convert=is_bf16, ), PostOpAttr( "none", None, "gelu", [], "tanh" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _gelu_fusion_tanh, get_qlinear_pt2e_pattern( x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), 4, is_bf16, ), with_dtype_convert=is_bf16, ), } for unary_attr, patterns in linear_unary_replace_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 3, # pass_number computation_op, unary_attr, # unary_attr ) # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output linear_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( qlinear_pattern, aten.relu.default ), PostOpAttr( "none", None, "gelu", [], "none" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _gelu_fusion_erf, get_qlinear_pt2e_pattern( x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), 2, is_bf16, ), Arg(), is_bf16, ), PostOpAttr( "none", None, "gelu", [], "tanh" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _gelu_fusion_tanh, get_qlinear_pt2e_pattern( x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), 4, is_bf16, ), Arg(), is_bf16, ), } for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 4, # pass_number computation_op, unary_attr, # unary_attr ) def _register_qlinear_binary_fusion(): r""" 
    Supported linear-binary(-unary) patterns

        linear(X)    extra input
               \    /
                Add
                 |
          Optional(relu)
                 |
                 Y

    1. int8-mixed-fp32
    +---+---------------+-----------+------------------------------+---------+
    | # | Add type      | Quant out | Pattern                      | Post op |
    +---+---------------+-----------+------------------------------+---------+
    | 1 | In-/out-place | Yes       | linear + fp32 -> (relu) -> q | add     |
    +---+---------------+-----------+------------------------------+---------+
    | 2 | In-/out-place | No        | linear + fp32 -> (relu)      | sum     |
    +---+---------------+-----------+------------------------------+---------+

    2. int8-mixed-bf16
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | # | X2 dtype | Add type      | Quant out | Pattern                                 | Post op |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 1 | BF16     | In-/out-place | Yes       | linear + bf16 -> (relu) -> q            | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 2 | BF16     | In-/out-place | No        | linear + bf16 -> (relu)                 | sum     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 3 | FP32     | Out-place     | Yes       | linear + fp32 -> (relu) -> q            | add     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 4 | FP32     | Out-place     | No        | linear + fp32 -> (relu)                 | sum     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 5 | FP32     | In-place left | Yes       | linear + fp32 -> to_bf16 -> (relu) -> q | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 6 | FP32     | In-place left | No        | linear + fp32 -> to_bf16 -> (relu)      | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+

    Note
    (1) The positions of linear and the extra input can be swapped.
    (2) By recipe, we don't insert q-dq before the extra input of linear-add.
        But if q-dq is found at the extra input, we don't match that pattern,
        because we cannot match all these patterns in 3 passes.
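
    For example, case (1) of int8-mixed-fp32 matches
        add(qlinear_fp32_output, X2) -> (optional relu) -> quantize_per_tensor
    and replaces it with a single onednn.qlinear_pointwise.binary(_tensor) call
    whose binary post op is "add"; case (2), where the fp32 output is kept,
    uses post op "sum" instead.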
""" for x_scale_zp_are_tensors in (False, True): qlinear_binary_op = ( torch.ops.onednn.qlinear_pointwise.binary_tensor if x_scale_zp_are_tensors else torch.ops.onednn.qlinear_pointwise.binary ) unary_postop_list = ["none", "relu"] unary_postop_dict = { "none": None, "relu": aten.relu.default, } convert_dtype_after_binary_list = [False, True] # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, # totally 3 patterns (2 are identical) swap_binary_inputs_list = [False, True] int8_mixed_bf16_list = [False, True] combinations = itertools.product( unary_postop_list, int8_mixed_bf16_list, swap_binary_inputs_list, convert_dtype_after_binary_list, ) qlinear_binary_replace_patterns = {} for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: if not int8_mixed_bf16 and cvt_dtype_binary: # No convert node after binary node if dtypes are all fp32 continue qlinear_binary_replace_patterns.update( { PostOpAttr( "add", 1.0, unary_op, [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), KeywordArg("other"), # If fp32 extra input is inplace added to bf16 linear output, # a to_bf16 node is inserted after binary dtype_convert=cvt_dtype_binary, swap_inputs=swap_inputs, ), unary_postop_dict[unary_op], ), ) } ) for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 3, # pass_number qlinear_binary_op, # computation_op binary_unary_attr, ) # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, # totally 2 patterns (2 are identical) binary_replace_float_out_patterns = {} for swap_binary_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), KeywordArg("accum"), dtype_convert=False, swap_inputs=swap_binary_inputs, ), aten.relu.default, ), } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 4, # pass_number qlinear_binary_op, # computation_op binary_unary_attr, ) # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output # Covers case (6) of int8-mixed-bf16 binary_replace_float_out_patterns = {} for swap_binary_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), KeywordArg("other"), dtype_convert=True, swap_inputs=swap_binary_inputs, ), aten.relu.default, ), } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 4, # pass_number qlinear_binary_op, # computation_op binary_unary_attr, ) # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, # totally 2 patterns (2 are identical) binary_replace_float_out_patterns = {} for swap_binary_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr( "sum", 1.0, "none", [], "" ): 
generate_pattern_with_binary( aten.add.Tensor, get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), KeywordArg("accum"), dtype_convert=False, swap_inputs=swap_binary_inputs, ), } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 5, # pass_number qlinear_binary_op, # computation_op binary_unary_attr, ) # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output # Covers (6) of int8-mixed-bf16 binary_replace_float_out_patterns = {} for swap_binary_inputs in swap_binary_inputs_list: binary_replace_float_out_patterns.update( { PostOpAttr( "add", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), KeywordArg("other"), dtype_convert=True, swap_inputs=swap_binary_inputs, ), } ) for ( binary_unary_attr, patterns, ) in binary_replace_float_out_patterns.items(): _register_qlinear_post_op_fusion_pass( patterns, 5, # pass_number qlinear_binary_op, # computation_op binary_unary_attr, ) @functools.cache def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 _register_dequant_promotion() # Step 2: QConv weight prepack _register_qconv_weight_prepack() # Step 3: QLinear weight prepack _register_qlinear_weight_prepack() _register_linear_dynamic_fp16_weight_prepack() # Step 4: weight prepack for SmoothQuant from Torchao _register_smooth_quant_int_mm_pattern() # Step 5: QLinear post op Fusion if not torch.ops.mkldnn._is_mkldnn_acl_supported(): # skip fusion on ARM _register_qconv_unary_fusion() _register_qconv_binary_fusion() _register_qlinear_unary_fusion() _register_qlinear_binary_fusion() def _is_valid_concat_linear_woq_int4_fusion(computation_nodes): computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default act = computation_nodes[0].args[0] wgt = computation_nodes[0].args[1] in_feature_size = wgt.meta.get("val").size(1) # type: ignore[union-attr] group_size = computation_nodes[0].args[2] return len(computation_nodes) >= 2 and all( ( node.target == computation_op and node.args[0] == act # share same activation and ( node.args[1].meta.get("val").size(1) == in_feature_size ) # same in feature size and (node.args[1] != wgt or gemm_idx == 0) and node.args[1].op == "get_attr" # wgt are all constants and node.args[2] == group_size # same group size ) for gemm_idx, node in enumerate(computation_nodes) ) def concat_linear_woq_int4(gm: torch.fx.GraphModule): """ Concat Linear optimization pass for WOQ int4 This pass fuses the original pattern: def ... return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...) into a single operation: def ... 
concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp) return split(concat_res, split_size_list) """ def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype): # Concat the wgts and scale_zps, and repack the wgt unpacked_wgts = [] for packed_wgt in packed_wgts: # Get the unpacked weight list # Same as https://github.com/pytorch/pytorch/pull/156174 K = packed_wgt.size(1) * 2 N = packed_wgt.size(0) x = torch.eye(K).to(dtype=act_dtype) qscales_and_zeros = ( torch.tensor([1.0, 8.0]) .to(dtype=act_dtype) .expand(K // group_size, N, 2) .contiguous() ) unpacked_wgts.append( torch.ops.aten._weight_int4pack_mm_for_cpu( x, packed_wgt, group_size, qscales_and_zeros, ) .t() .contiguous() .to(torch.int32) # N, K ) concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0) repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu( concat_unpacked_wgt, 1 ) concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous() return repack_w, concat_scale_zp graph = gm.graph computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default for node in graph.find_nodes(op="call_function", target=computation_op): if ( not node._erased and isinstance(node.meta.get("val"), torch.Tensor) and node.meta["val"].device.type == "cpu" ): act = node.args[0] users = list(act.users) if _is_valid_concat_linear_woq_int4_fusion(users): with graph.inserting_before(node): assert all(user.args[1].op == "get_attr" for user in users) computation_node_0 = users[0] packed_wgts = [getattr(gm, user.args[1].target) for user in users] group_size = computation_node_0.args[2] scale_zps = [getattr(gm, user.args[3].target) for user in users] out_feature_size_list = [ packed_wgt.size(0) for packed_wgt in packed_wgts ] repack_w, concat_scale_zp = concat_wgt( packed_wgts, scale_zps, group_size, act.meta.get("val").dtype ) repack_w_node_name = computation_node_0.args[1].target + "_concat" concat_scale_zp_node_name = ( computation_node_0.args[3].target + "_concat" ) gm.register_buffer(repack_w_node_name, repack_w) setattr(gm, repack_w_node_name, repack_w) gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp) setattr(gm, concat_scale_zp_node_name, concat_scale_zp) repack_w_node = graph.create_node( "get_attr", repack_w_node_name, (), {} ) with graph.inserting_after(repack_w_node): concat_scale_zp_node = graph.create_node( "get_attr", concat_scale_zp_node_name, (), {} ) with graph.inserting_after(concat_scale_zp_node): concat_int4_gemm_node = graph.create_node( "call_function", computation_op, ( act, repack_w_node, group_size, concat_scale_zp_node, ), ) with graph.inserting_after(concat_int4_gemm_node): split_node = graph.create_node( "call_function", torch.ops.aten.split_with_sizes.default, ( concat_int4_gemm_node, out_feature_size_list, 1, # split dim ), ) with graph.inserting_after(split_node): for gemm_idx, user in enumerate(users): assert user.target == computation_op get_item = graph.create_node( "call_function", operator.getitem, ( split_node, gemm_idx, ), ) with graph.inserting_after(get_item): clone_node = graph.create_node( "call_function", torch.ops.aten.clone.default, (get_item,), {"memory_format": torch.contiguous_format}, ) user.replace_all_uses_with(clone_node) graph.erase_node(user) def quant_lift_up(graph_module: torch.fx.GraphModule): """ Lift up the quant node before view like nodes. It can benefit performance of Attention like block. 
    For example, we have the pattern as:

        DQ         DQ
        LINEAR     LINEAR
        VIEW       VIEW
        PERMUTE    PERMUTE
                   TRANSPOSE
        Q          Q
        DQ         DQ
             Matmul
              DIV
              ADD
            SOFTMAX

    We want to lift the quant nodes up across the view-like nodes, so that they
    sit directly at the output of the Linear nodes:

        DQ         DQ
        LINEAR     LINEAR
        Q          Q
        VIEW       VIEW
        PERMUTE    PERMUTE
                   TRANSPOSE
        DQ         DQ
             Matmul
              DIV
              ADD
            SOFTMAX

    This produces a DQ->LINEAR->Q pattern which can be fused by the backend.
    """

    def is_view_op(node):
        return node.op == "call_function" and node.target in _VIEW_OPS

    for node in graph_module.graph.nodes:
        # Leslie: Here we verify that the quant node has exactly
        # one input FX node, with constant scalar values for scale and zero point.
        # For the case where the quant node has more than one input FX node,
        # extend the implementation to lift up all the connected nodes
        # before the view nodes to keep the topological order.
        if (
            node.op == "call_function"
            and node.target in _PER_TENSOR_QUANTIZE_OPS
            and len(node.all_input_nodes) == 1
            and is_view_op(node.all_input_nodes[0])
        ):
            quant_node = node
            input_node_of_quant = quant_node.args[0]

            # Check that each node along the lift-up path has only 1 user node.
            # Propagate through the view-like nodes to find where to insert the new quant node.
            could_lift_up = True
            current_node = quant_node
            input_node = current_node.args[0]
            while is_view_op(input_node):
                if len(input_node.users) != 1:
                    could_lift_up = False
                    break
                current_node = input_node
                input_node = current_node.args[0]

            # Further check that the input node of the first view node has only 1 user node
            if could_lift_up and len(input_node.users) == 1:
                # Replace the dequant's input: users of the quant node now consume
                # the quant node's input directly
                quant_node.replace_all_uses_with(input_node_of_quant)

                # Insert the new quant node
                with graph_module.graph.inserting_before(current_node):
                    new_quant_node = graph_module.graph.node_copy(quant_node)
                    input_node.replace_all_uses_with(new_quant_node)

                    # Update inputs of new_quant_node
                    def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
                        if n == input_node_of_quant:
                            return input_node
                        else:
                            return n

                    new_args = map_arg(new_quant_node.args, maybe_replace_node)
                    new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
                    new_quant_node.args = new_args  # type: ignore[assignment]
                    new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
                    graph_module.graph.erase_node(quant_node)

    graph_module.graph.lint()
    graph_module.recompile()
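
# Illustrative sketch of the quant_lift_up() rewrite above, assuming a
# single-user chain of view-like ops (the only case the pass handles):
#
#   before: linear -> view -> permute -> quantize_per_tensor -> dequantize -> matmul
#   after:  linear -> quantize_per_tensor -> view -> permute -> dequantize -> matmul
#
# This exposes the dequant -> linear -> quant chains that the oneDNN fusion
# patterns registered earlier in this file can match.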