Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import warnings | |
| from pathlib import Path | |
| from typing import Optional, Union | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmengine import Config | |
| from mmengine.registry import init_default_scope | |
| from mmengine.runner import load_checkpoint | |
| from mmengine.utils import mkdir_or_exist | |
| from mmseg.models import BaseSegmentor | |
| from mmseg.registry import MODELS | |
| from mmseg.structures import SegDataSample | |
| from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette | |
| from mmseg.visualization import SegLocalVisualizer | |
| from .utils import ImageType, _preprare_data | |
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed segmentor, moved to ``device`` and switched
        to eval mode. The config is attached as ``model.cfg``. When a
        checkpoint is given, ``model.dataset_meta`` is set as well.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Drop the init/pretrained/train settings before building the model.
    # NOTE(review): presumably this keeps MODELS.build from trying to load
    # pretrained backbone weights at inference time -- confirm.
    if config.model.type == 'EncoderDecoder':
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    elif config.model.type == 'MultimodalEncoderDecoder':
        for key, value in config.model.items():
            if isinstance(value, dict) and 'init_cfg' in value:
                config.model[key].init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None

    init_default_scope(config.get('default_scope', 'mmseg'))
    model = MODELS.build(config.model)

    if checkpoint is not None:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        # Read the metadata once and tolerate checkpoints without a 'meta'
        # entry, so they reach the num_classes fallback below instead of
        # failing with a KeyError.
        meta = loaded.get('meta') or {}
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmseg 1.x
            model.dataset_meta = meta['dataset_meta']
        elif 'CLASSES' in meta:
            # < mmseg 1.x
            model.dataset_meta = {
                'classes': meta['CLASSES'],
                'palette': meta['PALETTE']
            }
        else:
            # Kept from the original: this changes the process-wide warning
            # filter so each distinct warning is shown once.
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, classes and palette will be '
                'set according to num_classes')
            num_classes = model.decode_head.num_classes
            # Pick the first known dataset whose class count matches the
            # decode head.
            dataset_name = next(
                (name for name in dataset_aliases
                 if len(get_classes(name)) == num_classes), None)
            if dataset_name is None:
                warnings.warn(
                    'No suitable dataset found, use Cityscapes by default')
                dataset_name = 'cityscapes'
            model.dataset_meta = {
                'classes': get_classes(dataset_name),
                'palette': get_palette(dataset_name)
            }

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Run the segmentor on one image or a batch of images.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
        If img is a list or tuple, the same length list type results
        will be returned, otherwise return the segmentation results directly.
    """
    # Build the model input; the helper also reports whether the caller
    # passed a batch.
    batch, batched_input = _preprare_data(img, model)

    # Inference only: no gradients are tracked during the forward pass.
    with torch.no_grad():
        predictions = model.test_step(batch)

    if batched_input:
        return predictions
    return predictions[0]
def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, Path, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       with_labels: Optional[bool] = True,
                       save_dir=None,
                       out_file=None):
    """Visualize the segmentation results on the image.

    Args:
        model (nn.Module): The loaded segmentor. Its ``dataset_meta`` must
            provide ``'classes'`` and ``'palette'``.
        img (str, :obj:`Path` or np.ndarray): Image filename or loaded image.
            A filename is read from disk in RGB channel order; an array is
            used as given.
        result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of pyplot figure.
            Default is ''.
        draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
        draw_pred (bool): Whether to draw Prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image.
            Default to True.
        with_labels (bool, optional): Add semantic labels in visualization
            result, Default to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to output file. Default to None.

    Returns:
        np.ndarray: the drawn image which channel is RGB.
    """
    # Use the wrapped model when the given object exposes one via `.module`.
    if hasattr(model, 'module'):
        model = model.module

    # Accept pathlib.Path as well as str for filenames, matching what
    # init_model accepts for its config path.
    if isinstance(img, (str, Path)):
        image = mmcv.imread(img, channel_order='rgb')
    else:
        image = img

    if save_dir is not None:
        mkdir_or_exist(save_dir)

    # init visualizer
    visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir=save_dir,
        alpha=opacity)
    visualizer.dataset_meta = dict(
        classes=model.dataset_meta['classes'],
        palette=model.dataset_meta['palette'])
    visualizer.add_datasample(
        name=title,
        image=image,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show,
        with_labels=with_labels)
    return visualizer.get_image()