# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities used by both the sync and async inference clients."""

import base64
import io
import json
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterable,
    BinaryIO,
    ContextManager,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    NoReturn,
    Optional,
    Union,
    overload,
)

from requests import HTTPError

from huggingface_hub.errors import (
    GenerationError,
    IncompleteGenerationError,
    OverloadedError,
    TextGenerationError,
    UnknownError,
    ValidationError,
)

from ..constants import ENDPOINT
from ..utils import (
    build_hf_headers,
    get_session,
    hf_raise_for_status,
    is_aiohttp_available,
    is_numpy_available,
    is_pillow_available,
)
from ._generated.types import (
    ChatCompletionStreamOutput,
    TextGenerationStreamOutput,
)

if TYPE_CHECKING:
    from aiohttp import ClientResponse, ClientSession
    from PIL.Image import Image


# TYPES
UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
ContentT = Union[BinaryT, PathT, UrlT]

# Used to set an "Accept: image/png" header on requests for tasks expecting an image
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)


# Dataclass for the model status, as used by the `get_model_status` method.
@dataclass
class ModelStatus:
    """
    This Dataclass represents the model status in the Hugging Face Inference API.

    Args:
        loaded (`bool`):
            If the model is currently loaded into Hugging Face's InferenceAPI. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    loaded: bool
    state: str
    compute_type: Dict
    framework: str
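

# Illustrative sketch (the field values below are hypothetical): the kind of data a
# `ModelStatus` instance is expected to hold when returned by `get_model_status`.
def _example_model_status() -> ModelStatus:
    return ModelStatus(
        loaded=True,
        state="Loaded",
        compute_type={"gpu": "a10g", "replicas": 1},
        framework="text-generation-inference",
    )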


## IMPORT UTILS


def _import_aiohttp():
    # Make sure `aiohttp` is installed on the machine.
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    import aiohttp

    return aiohttp


def _import_numpy():
    """Make sure `numpy` is installed on the machine."""
    if not is_numpy_available():
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy


def _import_pil_image():
    """Make sure `PIL` is installed on the machine."""
    if not is_pillow_available():
        raise ImportError(
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image


## RECOMMENDED MODELS
# Will be globally fetched only once (see '_fetch_recommended_models')
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    global _RECOMMENDED_MODELS
    if _RECOMMENDED_MODELS is None:
        response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
        hf_raise_for_status(response)
        _RECOMMENDED_MODELS = {
            task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
        }
    return _RECOMMENDED_MODELS


def _first_or_none(items: List[Any]) -> Optional[Any]:
    try:
        return items[0] or None
    except IndexError:
        return None
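

# Illustrative sketch ("some-org/some-model" is a made-up id): the mapping returned by
# `_fetch_recommended_models()` looks like {"<task>": "<recommended model id or None>"},
# and `_first_or_none` is the small helper that picks the first widget model, if any.
def _example_first_or_none() -> None:
    assert _first_or_none(["some-org/some-model", "another-org/another-model"]) == "some-org/some-model"
    assert _first_or_none([]) is None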


## ENCODING / DECODING UTILS


@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"


@contextmanager  # type: ignore
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.

    Do nothing if `content` is None.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    # If content is a string => must be either a URL or a path
    if isinstance(content, str):
        if content.startswith("https://") or content.startswith("http://"):
            logger.debug(f"Downloading content from {content}")
            yield get_session().get(content).content  # TODO: retrieve as stream and pipe to post request ?
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # If content is a Path => open it
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        with content.open("rb") as f:
            yield f
    else:
        # Otherwise: already a file-like object or None
        yield content
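

# Illustrative sketch: `_open_as_binary` accepts raw bytes, a file-like object, a local
# path or an http(s) URL. With raw bytes it simply yields the input back unchanged; a
# path would be opened in "rb" mode and a URL would be downloaded first.
def _example_open_as_binary() -> None:
    with _open_as_binary(b"raw audio or image bytes") as data:
        assert data == b"raw audio or image bytes"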


def _b64_encode(content: ContentT) -> str:
    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
    with _open_as_binary(content) as data:
        data_as_bytes = data if isinstance(data, bytes) else data.read()
        return base64.b64encode(data_as_bytes).decode()


def _b64_to_image(encoded_image: str) -> "Image":
    """Parse a base64-encoded string into a PIL Image."""
    Image = _import_pil_image()
    return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
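

# Illustrative sketch: `_b64_encode` returns a plain base64 string, so decoding it with
# the standard library gives back the original bytes. `_b64_to_image` would additionally
# require Pillow and a payload that is a valid encoded image.
def _example_b64_round_trip() -> None:
    encoded = _b64_encode(b"hello")
    assert base64.b64decode(encoded) == b"hello"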


def _bytes_to_list(content: bytes) -> List:
    """Parse bytes from a Response object into a Python list.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
    dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_dict(content: bytes) -> Dict:
    """Parse bytes from a Response object into a Python dictionary.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
    list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_image(content: bytes) -> "Image":
    """Parse bytes from a Response object into a PIL Image.

    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
    """
    Image = _import_pil_image()
    return Image.open(io.BytesIO(content))


## STREAMING UTILS


def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`."""
    # Parse ServerSentEvents
    for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            break
        if output is not None:
            yield output


async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`."""
    # Parse ServerSentEvents
    async for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            break
        if output is not None:
            yield output


def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))

    # Either an error is returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or the token payload is parsed
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
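

# Illustrative sketch: how individual SSE lines are handled. A real "data:" line carries a
# JSON payload whose shape is defined by `TextGenerationStreamOutput` (a "token" object,
# optional "details", etc.).
def _example_stream_line_handling() -> None:
    # Empty keep-alive lines are ignored.
    assert _format_text_generation_stream_output(b"", details=False) is None
    # The "[DONE]" sentinel ends the stream by raising StopIteration.
    try:
        _format_text_generation_stream_output(b"data: [DONE]", details=False)
    except StopIteration:
        pass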


def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
    for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            break
        if output is not None:
            yield output


async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`."""
    async for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            break
        if output is not None:
            yield output


def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))

    # Either an error is returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or the chunk payload is parsed
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)


async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
    async for byte_payload in response.content:
        yield byte_payload.strip()
    await client.close()
| # "TGI servers" are servers running with the `text-generation-inference` backend. | |
| # This backend is the go-to solution to run large language models at scale. However, | |
| # for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference` | |
| # solution is still in use. | |
| # | |
| # Both approaches have very similar APIs, but not exactly the same. What we do first in | |
| # the `text_generation` method is to assume the model is served via TGI. If we realize | |
| # it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the | |
| # default API with a warning message. When that's the case, We remember the unsupported | |
| # attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable. | |
| # | |
| # In addition, TGI servers have a built-in API route for chat-completion, which is not | |
| # available on the default API. We use this route to provide a more consistent behavior | |
| # when available. | |
| # | |
| # For more details, see https://github.com/huggingface/text-generation-inference and | |
| # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task. | |
| _UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} | |
| def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: | |
| _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) | |
| def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: | |
| return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) | |
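

# Illustrative sketch ("my-org/my-model" is a hypothetical model id): once a parameter has
# been reported as unsupported for a given model, it is remembered for subsequent calls.
def _example_unsupported_kwargs_cache() -> None:
    _set_unsupported_text_generation_kwargs("my-org/my-model", ["watermark"])
    assert "watermark" in _get_unsupported_text_generation_kwargs("my-org/my-model")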


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
# inference project (https://github.com/huggingface/text-generation-inference).
# ----------------------


def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse text-generation-inference error message and raise HTTPError in any case.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    # Try to parse a Text Generation Inference error
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error = payload.get("error")
        error_type = payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    if error_type is not None:
        exception = _parse_text_generation_error(error, error_type)
        raise exception from http_error

    # Otherwise, fallback to default error
    raise http_error


def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    if error_type == "generation":
        return GenerationError(error)  # type: ignore
    if error_type == "incomplete_generation":
        return IncompleteGenerationError(error)  # type: ignore
    if error_type == "overloaded":
        return OverloadedError(error)  # type: ignore
    if error_type == "validation":
        return ValidationError(error)  # type: ignore
    return UnknownError(error)  # type: ignore
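

# Illustrative sketch: each known `error_type` maps to a dedicated exception class, and
# anything unrecognized falls back to `UnknownError`.
def _example_parse_text_generation_error() -> None:
    assert isinstance(_parse_text_generation_error("Input validation error", "validation"), ValidationError)
    assert isinstance(_parse_text_generation_error("something went wrong", None), UnknownError)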