Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Literal, Optional, Union | |
| import requests | |
| from tqdm.auto import tqdm as base_tqdm | |
| from tqdm.contrib.concurrent import thread_map | |
| from . import constants | |
| from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError | |
| from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name | |
| from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo | |
| from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args | |
| from .utils import tqdm as hf_tqdm | |
| logger = logging.get_logger(__name__) | |
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[base_tqdm] = None,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    # Deprecated args
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    resume_download: Optional[bool] = None,
) -> str:
    """Download repo files.

    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
    a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder. You can also filter which files to download using
    `allow_patterns` and `ignore_patterns`.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
    cache-system, it's optimized for regularly pulling the latest version of a repository.

    An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
    configured. It is also not possible to filter which files to download when cloning a repository using git.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `constants.DEFAULT_REVISION`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to `constants.HF_HUB_CACHE`.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded files will be placed under this directory.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `constants.DEFAULT_ETAG_TIMEOUT`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in the local cache.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        headers (`dict`, *optional*):
            Additional headers to include in the request. Those headers take precedence over the others.
        endpoint (`str`, *optional*):
            Forwarded as-is to `HfApi` (repo listing) and to `hf_hub_download` (file downloads).
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*):
            Number of concurrent threads to download files (1 thread = 1 file download).
            Defaults to 8.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed
            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
            Note that the `tqdm_class` is not passed to each individual download.
            Defaults to the custom HF progress bar that can be disabled by setting
            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
        local_dir_use_symlinks (`bool` or `"auto"`, *optional*):
            Deprecated. Forwarded unchanged to `hf_hub_download`.
        resume_download (`bool`, *optional*):
            Deprecated. Forwarded unchanged to `hf_hub_download`.

    Returns:
        `str`: folder path of the repo snapshot.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`~utils.LocalEntryNotFoundError`]
            If the Hub could not be queried (offline mode, `local_files_only=True`, connection or
            HTTP error) and no matching snapshot folder or non-empty `local_dir` exists on disk.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` and the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
            ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid.
    """
    # Resolve defaults and normalize argument types.
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE
    if revision is None:
        revision = constants.DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
    api_call_error: Optional[Exception] = None
    if not local_files_only:
        # try/except logic to handle different errors => taken from `hf_hub_download`
        try:
            # if we have internet connection we want to list files to download
            api = HfApi(
                library_name=library_name,
                library_version=library_version,
                user_agent=user_agent,
                endpoint=endpoint,
                headers=headers,
            )
            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Internet connection is down
            # => will try to use local files only
            api_call_error = error
        except RevisionNotFoundError:
            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            api_call_error = error

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and invalid/missing token sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    # => if local_dir is not None, we will return the path to the local folder if it exists.
    if repo_info is None:
        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                with open(ref_path) as f:
                    # Strip surrounding whitespace: a stray trailing newline would
                    # otherwise end up in the snapshot path and make the lookup miss.
                    commit_hash = f.read().strip()

        # Try to locate snapshot folder for this commit hash
        # NOTE(review): this returns the cache snapshot even when `local_dir` was
        # requested - confirm whether `local_dir` should take precedence here.
        if commit_hash is not None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder

        # If local_dir is not None, return it if it exists and is not empty
        if local_dir is not None:
            local_dir = Path(local_dir)
            if local_dir.is_dir() and any(local_dir.iterdir()):
                logger.warning(
                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
                )
                return str(local_dir.resolve())

        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
            # Repo not found => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running
    # => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)

    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        try:
            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
            with open(ref_path, "w") as f:
                f.write(commit_hash)
        except OSError as error:
            # The ref is only a pointer used for later offline lookups. Failing to
            # write it (e.g. unwritable cache directory) must not abort the download,
            # in particular when files are going to `local_dir` rather than the cache.
            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {error}.")

    # we pass the commit_hash to hf_hub_download
    # so no network call happens if we already
    # have the file locally.
    def _inner_hf_hub_download(repo_file: str):
        return hf_hub_download(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            endpoint=endpoint,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            force_download=force_download,
            token=token,
            headers=headers,
        )

    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        # when using hf_transfer we don't want extra parallelism
        # from the one hf_transfer provides
        for file in filtered_repo_files:
            _inner_hf_hub_download(file)
    else:
        thread_map(
            _inner_hf_hub_download,
            filtered_repo_files,
            desc=f"Fetching {len(filtered_repo_files)} files",
            max_workers=max_workers,
            # User can use its own tqdm class or the default one from `huggingface_hub.utils`
            tqdm_class=tqdm_class or hf_tqdm,
        )

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder