import json from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional, Union from urllib.parse import urlparse, urlunparse from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import ( MimeBytes, RequestParameters, _b64_encode, _bytes_to_dict, _open_as_mime_bytes, ) from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status class HFInferenceTask(TaskProviderHelper): """Base class for HF Inference API tasks.""" def __init__(self, task: str): super().__init__( provider="hf-inference", base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"), task=task, ) def _prepare_api_key(self, api_key: Optional[str]) -> str: # special case: for HF Inference we allow not providing an API key return api_key or get_token() # type: ignore[return-value] def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: if model is not None and model.startswith(("http://", "https://")): return InferenceProviderMapping( provider="hf-inference", providerId=model, hf_model_id=model, task=self.task, status="live" ) model_id = model if model is not None else _fetch_recommended_models().get(self.task) if model_id is None: raise ValueError( f"Task {self.task} has no recommended model for HF Inference. Please specify a model" " explicitly. Visit https://huggingface.co/tasks for more info." ) _check_supported_task(model_id, self.task) return InferenceProviderMapping( provider="hf-inference", providerId=model_id, hf_model_id=model_id, task=self.task, status="live" ) def _prepare_url(self, api_key: str, mapped_model: str) -> str: # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment) if mapped_model.startswith(("http://", "https://")): return mapped_model return ( # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks. f"{self.base_url}/models/{mapped_model}/pipeline/{self.task}" if self.task in ("feature-extraction", "sentence-similarity") # Otherwise, we use the default endpoint else f"{self.base_url}/models/{mapped_model}" ) def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})") return filter_none({"inputs": inputs, "parameters": parameters}) class HFInferenceBinaryInputTask(HFInferenceTask): def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: return None def _prepare_payload_as_bytes( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping, extra_payload: Optional[Dict], ) -> Optional[MimeBytes]: parameters = filter_none(parameters) extra_payload = extra_payload or {} has_parameters = len(parameters) > 0 or len(extra_payload) > 0 # Raise if not a binary object or a local path or a URL. if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str): raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}") # Send inputs as raw content when no parameters are provided if not has_parameters: return _open_as_mime_bytes(inputs) # Otherwise encode as b64 return MimeBytes( json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8"), mime_type="application/json", ) class HFInferenceConversational(HFInferenceTask): def __init__(self): super().__init__("conversational") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: payload = filter_none(parameters) mapped_model = provider_mapping_info.provider_id payload_model = parameters.get("model") or mapped_model if payload_model is None or payload_model.startswith(("http://", "https://")): payload_model = "dummy" response_format = parameters.get("response_format") if isinstance(response_format, dict) and response_format.get("type") == "json_schema": payload["response_format"] = { "type": "json_object", "value": response_format["json_schema"]["schema"], } return {**payload, "model": payload_model, "messages": inputs} def _prepare_url(self, api_key: str, mapped_model: str) -> str: base_url = ( mapped_model if mapped_model.startswith(("http://", "https://")) else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}" ) return _build_chat_completion_url(base_url) def _build_chat_completion_url(model_url: str) -> str: parsed = urlparse(model_url) path = parsed.path.rstrip("/") # If the path already ends with /chat/completions, we're done! if path.endswith("/chat/completions"): return model_url # Append /chat/completions if not already present if path.endswith("/v1"): new_path = path + "/chat/completions" # If path was empty or just "/", set the full path elif not path: new_path = "/v1/chat/completions" # Append /v1/chat/completions if not already present else: new_path = path + "/v1/chat/completions" # Reconstruct the URL with the new path and original query parameters. new_parsed = parsed._replace(path=new_path) return str(urlunparse(new_parsed)) @lru_cache(maxsize=1) def _fetch_recommended_models() -> Dict[str, Optional[str]]: response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers()) hf_raise_for_status(response) return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()} @lru_cache(maxsize=None) def _check_supported_task(model: str, task: str) -> None: from huggingface_hub.hf_api import HfApi model_info = HfApi().model_info(model) pipeline_tag = model_info.pipeline_tag tags = model_info.tags or [] is_conversational = "conversational" in tags if task in ("text-generation", "conversational"): if pipeline_tag == "text-generation": # text-generation + conversational tag -> both tasks allowed if is_conversational: return # text-generation without conversational tag -> only text-generation allowed if task == "text-generation": return raise ValueError(f"Model '{model}' doesn't support task '{task}'.") if pipeline_tag == "text2text-generation": if task == "text-generation": return raise ValueError(f"Model '{model}' doesn't support task '{task}'.") if pipeline_tag == "image-text-to-text": if is_conversational and task == "conversational": return # Only conversational allowed if tagged as conversational raise ValueError("Non-conversational image-text-to-text task is not supported.") if ( task in ("feature-extraction", "sentence-similarity") and pipeline_tag in ("feature-extraction", "sentence-similarity") and task in tags ): # feature-extraction and sentence-similarity are interchangeable for HF Inference return # For all other tasks, just check pipeline tag if pipeline_tag != task: raise ValueError( f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'" ) return class HFInferenceFeatureExtractionTask(HFInferenceTask): def __init__(self): super().__init__("feature-extraction") def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: if isinstance(inputs, bytes): raise ValueError(f"Unexpected binary input for task {self.task}.") if isinstance(inputs, Path): raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})") # Parameters are sent at root-level for feature-extraction task # See specs: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/feature-extraction/spec/input.json return {"inputs": inputs, **filter_none(parameters)} def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: if isinstance(response, bytes): return _bytes_to_dict(response) return response