import json import os import uuid from abc import ABC, abstractmethod from typing import Dict, List, Optional import base64 import hashlib import functools import sysconfig from triton import __version__, knobs class CacheManager(ABC): def __init__(self, key, override=False, dump=False): pass @abstractmethod def get_file(self, filename) -> Optional[str]: pass @abstractmethod def put(self, data, filename, binary=True) -> str: pass @abstractmethod def get_group(self, filename: str) -> Optional[Dict[str, str]]: pass @abstractmethod def put_group(self, filename: str, group: Dict[str, str]): pass class FileCacheManager(CacheManager): def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None if dump: self.cache_dir = knobs.cache.dump_dir self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) elif override: self.cache_dir = knobs.cache.override_dir self.cache_dir = os.path.join(self.cache_dir, self.key) else: # create cache directory if it doesn't exist self.cache_dir = knobs.cache.dir if self.cache_dir: self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) else: raise RuntimeError("Could not create or locate cache dir") def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) def has_file(self, filename) -> bool: if not self.cache_dir: raise RuntimeError("Could not create or locate cache dir") return os.path.exists(self._make_path(filename)) def get_file(self, filename) -> Optional[str]: if self.has_file(filename): return self._make_path(filename) else: return None def get_group(self, filename: str) -> Optional[Dict[str, str]]: grp_filename = f"__grp__{filename}" if not self.has_file(grp_filename): return None grp_filepath = self._make_path(grp_filename) with open(grp_filepath) as f: grp_data = json.load(f) child_paths = grp_data.get("child_paths", None) # Invalid group data. if child_paths is None: return None result = {} for c, p in child_paths.items(): if os.path.exists(p): result[c] = p return result # Note a group of pushed files as being part of a group def put_group(self, filename: str, group: Dict[str, str]) -> str: if not self.cache_dir: raise RuntimeError("Could not create or locate cache dir") grp_contents = json.dumps({"child_paths": group}) grp_filename = f"__grp__{filename}" return self.put(grp_contents, grp_filename, binary=False) def put(self, data, filename, binary=True) -> str: if not self.cache_dir: raise RuntimeError("Could not create or locate cache dir") binary = isinstance(data, bytes) if not binary: data = str(data) assert self.lock_path is not None filepath = self._make_path(filename) # Random ID to avoid any collisions rnd_id = str(uuid.uuid4()) # we use the PID in case a bunch of these around so we can see what PID made it pid = os.getpid() # use temp dir to be robust against program interruptions temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") os.makedirs(temp_dir, exist_ok=True) temp_path = os.path.join(temp_dir, filename) mode = "wb" if binary else "w" with open(temp_path, mode) as f: f.write(data) # Replace is guaranteed to be atomic on POSIX systems if it succeeds # so filepath cannot see a partial write os.replace(temp_path, filepath) os.removedirs(temp_dir) return filepath class RemoteCacheBackend: """ A backend implementation for accessing a remote/distributed cache. """ def __init__(self, key: str): pass @abstractmethod def get(self, filenames: List[str]) -> Dict[str, bytes]: pass @abstractmethod def put(self, filename: str, data: bytes): pass class RedisRemoteCacheBackend(RemoteCacheBackend): def __init__(self, key): import redis self._key = key self._key_fmt = knobs.cache.redis.key_format self._redis = redis.Redis( host=knobs.cache.redis.host, port=knobs.cache.redis.port, ) def _get_key(self, filename: str) -> str: return self._key_fmt.format(key=self._key, filename=filename) def get(self, filenames: List[str]) -> Dict[str, str]: results = self._redis.mget([self._get_key(f) for f in filenames]) return {filename: result for filename, result in zip(filenames, results) if result is not None} def put(self, filename: str, data: bytes) -> Dict[str, bytes]: self._redis.set(self._get_key(filename), data) class RemoteCacheManager(CacheManager): def __init__(self, key, override=False, dump=False): # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. remote_cache_cls = knobs.cache.remote_manager_class if not remote_cache_cls: raise RuntimeError( "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class") self._backend = remote_cache_cls(key) self._override = override self._dump = dump # Use a `FileCacheManager` to materialize remote cache paths locally. self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) def _materialize(self, filename: str, data: bytes): # We use a backing `FileCacheManager` to provide the materialized data. return self._file_cache_manager.put(data, filename, binary=True) def get_file(self, filename: str) -> Optional[str]: # We don't handle the dump/override cases. if self._dump or self._override: return self._file_cache_manager.get_file(filename) # We always check the remote cache backend -- even if our internal file- # based cache has the item -- to make sure LRU accounting works as # expected. results = self._backend.get([filename]) if len(results) == 0: return None (_, data), = results.items() return self._materialize(filename, data) def put(self, data, filename: str, binary=True) -> str: # We don't handle the dump/override cases. if self._dump or self._override: return self._file_cache_manager.put(data, filename, binary=binary) if not isinstance(data, bytes): data = str(data).encode("utf-8") self._backend.put(filename, data) return self._materialize(filename, data) def get_group(self, filename: str) -> Optional[Dict[str, str]]: # We don't handle the dump/override cases. if self._dump or self._override: return self._file_cache_manager.get_group(filename) grp_filename = f"__grp__{filename}" grp_filepath = self.get_file(grp_filename) if grp_filepath is None: return None with open(grp_filepath) as f: grp_data = json.load(f) child_paths = grp_data.get("child_paths", None) result = None # Found group data. if child_paths is not None: result = {} for child_path, data in self._backend.get(child_paths).items(): result[child_path] = self._materialize(child_path, data) return result def put_group(self, filename: str, group: Dict[str, str]): # We don't handle the dump/override cases. if self._dump or self._override: return self._file_cache_manager.put_group(filename, group) grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) grp_filename = f"__grp__{filename}" return self.put(grp_contents, grp_filename) def _base32(key): # Assume key is a hex string. return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") def get_cache_manager(key) -> CacheManager: cls = knobs.cache.manager_class or FileCacheManager return cls(_base32(key)) def get_override_manager(key) -> CacheManager: cls = knobs.cache.manager_class or FileCacheManager return cls(_base32(key), override=True) def get_dump_manager(key) -> CacheManager: cls = knobs.cache.manager_class or FileCacheManager return cls(_base32(key), dump=True) def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): # Get unique key for the compiled code signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" for kw in kwargs: key = f"{key}-{kwargs.get(kw)}" key = hashlib.sha256(key.encode("utf-8")).hexdigest() return _base32(key) @functools.lru_cache() def triton_key(): import pkgutil TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) contents = [] # frontend with open(__file__, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] # compiler path_prefixes = [ (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), (os.path.join(TRITON_PATH, "backends"), "triton.backends."), ] for path, prefix in path_prefixes: for lib in pkgutil.walk_packages([path], prefix=prefix): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] # backend libtriton_hash = hashlib.sha256() ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f: while True: chunk = f.read(1024**2) if not chunk: break libtriton_hash.update(chunk) contents.append(libtriton_hash.hexdigest()) # language language_path = os.path.join(TRITON_PATH, 'language') for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] return f'{__version__}' + '-'.join(contents) def get_cache_key(src, backend, backend_options, env_vars): key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}" return key