import logging from dataclasses import dataclass from typing import Optional, Union import torch log = logging.getLogger(__name__) @dataclass(frozen=True) class DeviceInfo: """ Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not, then the higher number is reported. Sparsity is not considered. Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace. For example, """ tops: dict[Union[torch.dtype, str], float] dram_bw_gbs: float dram_gb: float # Indexing is based on `torch.cuda.get_device_name()` # TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on. _device_mapping: dict[str, DeviceInfo] = { # Source: # @lint-ignore https://www.nvidia.com/en-us/data-center/h100/ "NVIDIA H100": DeviceInfo( tops={ torch.float64: 67.0, torch.float32: 67.5, "torch.tf32": 156.0, torch.bfloat16: 1979.0, torch.float16: 1979.0, torch.float8_e8m0fnu: 3958.0, torch.float8_e8m0fnu: 3958.0, torch.float8_e4m3fnuz: 3958.0, torch.float8_e5m2: 3958.0, torch.float8_e5m2fnuz: 3958.0, torch.float8_e8m0fnu: 3958.0, torch.int8: 3958.0, }, dram_bw_gbs=3350, dram_gb=80, ), # Source: # @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/ # nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf "NVIDIA A100": DeviceInfo( tops={ torch.float64: 19.5, torch.float32: 19.5, torch.bfloat16: 312.5, torch.float16: 312.5, # Not in datasheet: float8 torch.int8: 624.0, "torch.tf32": 156.0, }, dram_bw_gbs=2039.0, dram_gb=80.0, ), # Source: # @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet "NVIDIA L4": DeviceInfo( tops={ # This is a guess, not in datasheet torch.float64: 15.1, torch.float32: 30.3, "torch.tf32": 120.0, torch.bfloat16: 242.0, torch.float16: 242.0, torch.float8_e8m0fnu: 485.0, torch.float8_e8m0fnu: 485.0, torch.float8_e4m3fnuz: 485.0, torch.float8_e5m2: 485.0, torch.float8_e5m2fnuz: 485.0, torch.float8_e8m0fnu: 485.0, torch.int8: 485.0, }, dram_bw_gbs=3350, dram_gb=24, ), # Source: # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\ # /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf "AMD MI300A": DeviceInfo( tops={ torch.float64: 122.6, torch.float32: 122.6, "torch.tf32": 490.3, torch.bfloat16: 980.6, torch.float16: 980.6, torch.float8_e8m0fnu: 1961.2, torch.float8_e8m0fnu: 1961.2, torch.float8_e4m3fnuz: 1961.2, torch.float8_e5m2: 1961.2, torch.float8_e5m2fnuz: 1961.2, torch.float8_e8m0fnu: 1961.2, torch.int8: 1961.2, }, dram_bw_gbs=5300.0, dram_gb=128.0, ), # Source: # @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\ # instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf "AMD MI300X": DeviceInfo( tops={ torch.float64: 163.4, torch.float32: 163.4, "torch.tf32": 653.7, torch.bfloat16: 1307.4, torch.float16: 1307.4, torch.float8_e8m0fnu: 2614.9, torch.float8_e8m0fnu: 2614.9, torch.float8_e4m3fnuz: 2614.9, torch.float8_e5m2: 2614.9, torch.float8_e5m2fnuz: 2614.9, torch.float8_e8m0fnu: 2614.9, torch.int8: 2614.9, }, dram_bw_gbs=5300.0, dram_gb=192.0, ), # Source: # @lint-ignore https://www.amd.com/content/dam/amd/\ # en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf "AMD MI210X": DeviceInfo( tops={ torch.float64: 45.3, torch.float32: 45.3, # not specified, fall back to float32 numbers "torch.tf32": 45.3, torch.bfloat16: 181.0, torch.float16: 181.0, # not specified, fall back to float16 numbers torch.float8_e8m0fnu: 181.0, torch.float8_e8m0fnu: 181.0, torch.float8_e4m3fnuz: 181.0, torch.float8_e5m2: 181.0, torch.float8_e5m2fnuz: 181.0, torch.float8_e8m0fnu: 181.0, torch.int8: 181.0, }, # pcie4.0x16 dram_bw_gbs=1600.0, dram_gb=64.0, ), } _device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"] _device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"] def lookup_device_info(name: str) -> Optional[DeviceInfo]: """ Problem: when diffing profiles between amd and nvidia, we don't have access to the device information of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices. If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping. name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name(). """ return _device_mapping.get(name, None) def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]: """ Get the theoretical TFLOPS of the device for a given dtype. This can throw an exception if the device is not in the datasheet list above. """ name: Optional[str] = torch.cuda.get_device_name() if name is None: log.info("No device found, returning None") return None device_info = lookup_device_info(name) if device_info is None: log_str = f"Device {name} not in datasheet, returning None" log.info(log_str) return None if dtype not in device_info.tops: log.info( "Device %s does not have a datasheet entry for %s, returning None", name, dtype, ) return None return device_info.tops[ "torch.tf32" if dtype == torch.float32 and is_tf32 else dtype ]