from typing import Any, Optional

import torch
from torch import nn
from torch.ao.quantization import QConfig


__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"]


class QuantStub(nn.Module):
    r"""Placeholder marking where a tensor should be quantized.

    Prior to calibration it plays the role of an observer; `convert`
    replaces it with `nnq.Quantize`. The forward pass of the stub itself
    hands the input back unchanged.

    Args:
        qconfig: quantization configuration for the tensor. When it is not
            given, the qconfig is taken from the parent modules.
    """

    def __init__(self, qconfig: Optional[QConfig] = None):
        super().__init__()
        # Leave the attribute unset for a missing/falsy config so that the
        # qconfig can be inherited from a parent module.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass-through: the stub performs no computation of its own.
        return x


class DeQuantStub(nn.Module):
    r"""Placeholder marking where a tensor should be dequantized.

    Prior to calibration it is an identity; `convert` replaces it with
    `nnq.DeQuantize`. The forward pass of the stub itself hands the input
    back unchanged.

    Args:
        qconfig: quantization configuration for the tensor. When it is not
            given, the qconfig is taken from the parent modules.
    """

    def __init__(self, qconfig: Optional[Any] = None):
        super().__init__()
        # Leave the attribute unset for a missing/falsy config so that the
        # qconfig can be inherited from a parent module.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass-through: the stub performs no computation of its own.
        return x


class QuantWrapper(nn.Module):
    r"""Wrap a module between a `QuantStub` and a `DeQuantStub`.

    The wrapped module is called with the output of the quant stub, and its
    result is passed through the dequant stub. The `quantization` utility
    functions use this wrapper to insert the quant and dequant modules.

    Before `convert`, `QuantStub` is only an observer of the input tensor;
    after `convert` it is swapped to `nnq.Quantize`, which performs the
    actual quantization. `DeQuantStub` is handled analogously.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module: nn.Module):
        super().__init__()
        # Both stubs receive the wrapped module's qconfig (None if it has none).
        wrapped_qconfig = getattr(module, "qconfig", None)
        children = (
            ("quant", QuantStub(wrapped_qconfig)),
            ("dequant", DeQuantStub(wrapped_qconfig)),
            ("module", module),
        )
        for child_name, child in children:
            self.add_module(child_name, child)
        # Mirror the wrapped module's train/eval mode on the whole wrapper.
        self.train(module.training)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # quant -> wrapped module -> dequant
        return self.dequant(self.module(self.quant(X)))