| """ | |
| Code adapted from timm https://github.com/huggingface/pytorch-image-models | |
| Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich | |
| """ | |
import os
from typing import Any, Dict, Optional, Union

# register new models
from mivolo.model.mivolo_model import *  # noqa: F403, F401
from timm.layers import set_layer_config
from timm.models._factory import parse_model_name
from timm.models._helpers import load_state_dict, remap_checkpoint
from timm.models._hub import load_model_config_from_hf
from timm.models._pretrained import PretrainedCfg, split_model_name_tag
from timm.models._registry import is_model, model_entrypoint


def load_checkpoint(
    model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
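    """Load weights from a checkpoint file into ``model``.

    MiVOLO-specific additions on top of the timm helper:
    filter_keys: list of substrings; any state_dict key containing one of them
        is dropped before loading (and the load falls back to strict=False).
    state_dict_map: dict of {new_key_substring: old_key_substring}; checkpoint
        keys containing the old substring are renamed to use the new one.
    """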
    if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, "load_pretrained"):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError("Model cannot load numpy checkpoint")
        return
    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)
    if filter_keys:
        # drop any state_dict keys that contain one of the filter substrings
        for sd_key in list(state_dict.keys()):
            for filter_key in filter_keys:
                if filter_key in sd_key and sd_key in state_dict:
                    del state_dict[sd_key]

    rep = []
    if state_dict_map is not None:
        # state_dict_map maps new key substrings to old ones,
        # e.g. {'patch_embed.conv1.': 'patch_embed.conv.'} renames checkpoint keys
        # containing 'patch_embed.conv.' to use 'patch_embed.conv1.'
        for state_k in list(state_dict.keys()):
            for target_k, target_v in state_dict_map.items():
                if target_v in state_k:
                    target_name = state_k.replace(target_v, target_k)
                    state_dict[target_name] = state_dict[state_k]
                    rep.append(state_k)

    # remove the original (now renamed) keys
    for r in rep:
        if r in state_dict:
            del state_dict[r]

    # a filtered checkpoint can no longer satisfy a strict load
    incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
    return incompatible_keys
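

# Usage sketch for the MiVOLO-specific arguments (the checkpoint path and key
# names below are hypothetical, for illustration only):
#
#   load_checkpoint(
#       model,
#       "weights.pth.tar",
#       filter_keys=["head."],  # drop classifier weights before loading
#       state_dict_map={"patch_embed.conv1.": "patch_embed.conv."},  # rename keys
#   )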


def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: str = "",
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    filter_keys=None,
    state_dict_map=None,
    **kwargs,
):
| """Create a model | |
| Lookup model's entrypoint function and pass relevant args to create a new model. | |
| """ | |
    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    model_source, model_name = parse_model_name(model_name)
    if model_source == "hf-hub":
        assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
        # load model weights + pretrained_cfg from Hugging Face hub.
        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
    else:
        model_name, pretrained_tag = split_model_name_tag(model_name)
        if not pretrained_cfg:
            # a valid pretrained_cfg argument takes priority over tag in model name
            pretrained_cfg = pretrained_tag

    if not is_model(model_name):
        raise RuntimeError("Unknown model (%s)" % model_name)

    create_fn = model_entrypoint(model_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **kwargs,
        )

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)

    return model
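

# Usage sketch (the model name and checkpoint path are hypothetical; any
# architecture registered by mivolo.model.mivolo_model will resolve here):
#
#   model = create_model(
#       "mivolo_d1_224",
#       checkpoint_path="checkpoints/mivolo.pth.tar",
#   )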