# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData

try:
    import motmetrics
    from motmetrics.lap import linear_sum_assignment
except ImportError:
    motmetrics = None
from torch import Tensor

from mmdet.models.utils import imrenormalize
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah
from mmdet.utils import OptConfigType
from .sort_tracker import SORTTracker


def cosine_distance(x: Tensor, y: Tensor) -> np.ndarray:
    """Compute the cosine distance.

    Args:
        x (Tensor): Embeddings with shape (N, C).
        y (Tensor): Embeddings with shape (M, C).

    Returns:
        ndarray: Cosine distance with shape (N, M).
    """
    x = x.cpu().numpy()
    y = y.cpu().numpy()
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    y = y / np.linalg.norm(y, axis=1, keepdims=True)
    dists = 1. - np.dot(x, y.T)
    return dists
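
# Illustrative usage of `cosine_distance` (a hypothetical sketch, not part of
# the tracker itself): identical embeddings give distance 0, orthogonal ones 1.
#
#   >>> import torch
#   >>> a = torch.tensor([[1., 0.], [0., 1.]])   # two track embeddings
#   >>> b = torch.tensor([[1., 0.]])             # one detection embedding
#   >>> cosine_distance(a, b)                    # shape (2, 1)
#   array([[0.],     # same direction  -> distance 0
#          [1.]])    # orthogonal      -> distance 1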


@MODELS.register_module()
class StrongSORTTracker(SORTTracker):
    """Tracker for StrongSORT.

    Args:
        obj_score_thr (float, optional): Threshold to filter the objects.
            Defaults to 0.6.
        motion (dict): Configuration of motion. Defaults to None.
        reid (dict, optional): Configuration for the ReID model.

            - num_samples (int, optional): Number of samples to calculate
              the feature embeddings of a track. Defaults to None.
            - img_scale (tuple, optional): Input scale of the ReID model.
              Defaults to (256, 128).
            - img_norm_cfg (dict, optional): Configuration to normalize the
              input. Defaults to None.
            - match_score_thr (float, optional): Similarity threshold for
              the matching process. Defaults to 0.3.
            - motion_weight (float, optional): Weight of the motion cost.
              Defaults to 0.02.
        match_iou_thr (float, optional): Threshold of the IoU matching
            process. Defaults to 0.7.
        num_tentatives (int, optional): Number of continuous frames to
            confirm a track. Defaults to 2.
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.6,
                 reid: dict = dict(
                     num_samples=None,
                     img_scale=(256, 128),
                     img_norm_cfg=None,
                     match_score_thr=0.3,
                     motion_weight=0.02),
                 match_iou_thr: float = 0.7,
                 num_tentatives: int = 2,
                 **kwargs):
        if motmetrics is None:
            raise RuntimeError('motmetrics is not installed, please install '
                               'it by: pip install motmetrics')
        super().__init__(motion, obj_score_thr, reid, match_iou_thr,
                         num_tentatives, **kwargs)
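
    # A minimal config sketch (assumed values, not taken from any shipped
    # config) showing how this tracker might be plugged into an MMDetection
    # tracking model through the registry. The `motion` module name is an
    # assumption; the remaining keys mirror the ``__init__`` signature above:
    #
    #   tracker=dict(
    #       type='StrongSORTTracker',
    #       motion=dict(type='KalmanFilter'),  # assumed motion module
    #       obj_score_thr=0.6,
    #       reid=dict(
    #           num_samples=None,
    #           img_scale=(256, 128),
    #           img_norm_cfg=None,
    #           match_score_thr=0.3,
    #           motion_weight=0.02),
    #       match_iou_thr=0.7,
    #       num_tentatives=2)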

    def update_track(self, id: int, obj: Tuple[Tensor]) -> None:
        """Update a track."""
        for k, v in zip(self.memo_items, obj):
            v = v[None]
            if self.momentums is not None and k in self.momentums:
                # exponential moving average for items with a configured
                # momentum (e.g. appearance embeddings)
                m = self.momentums[k]
                self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v
            else:
                self.tracks[id][k].append(v)

        # confirm a tentative track once it has accumulated
        # `num_tentatives` boxes
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False

        # update the Kalman filter state with the matched detection;
        # the detection score is also passed to the update
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        score = float(self.tracks[id].scores[-1].cpu())
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox, score)
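
    # For intuition (hypothetical numbers): if a momentum of m = 0.1 is
    # configured for the 'embeds' item, an existing track feature of 1.0 and
    # a new detection feature of 0.0 are blended as
    #   (1 - 0.1) * 1.0 + 0.1 * 0.0 = 0.9
    # so the stored appearance changes slowly, matching the EMA-style
    # appearance update used by StrongSORT.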

    def track(self,
              model: torch.nn.Module,
              img: Tensor,
              data_sample: TrackDataSample,
              data_preprocessor: OptConfigType = None,
              rescale: bool = False,
              **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model.
            img (Tensor): of shape (T, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
                The T denotes the number of key images and usually is 1 in
                the SORT method.
            data_sample (:obj:`TrackDataSample`): The data sample.
                It includes information such as ``pred_instances``.
            data_preprocessor (dict or ConfigDict, optional): The pre-process
                config of :class:`TrackDataPreprocessor`. It usually includes
                ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            rescale (bool, optional): If True, the bounding boxes should be
                rescaled to fit the original scale of the image. Defaults to
                False.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            # first frame of a new video: drop all existing tracks
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.with_reid:
            if self.reid.get('img_norm_cfg', False):
                # re-normalize the image when the ReID model expects a
                # different mean/std than the detector's preprocessor
                img_norm_cfg = dict(
                    mean=data_preprocessor.get('mean', [0, 0, 0]),
                    std=data_preprocessor.get('std', [1, 1, 1]),
                    to_bgr=data_preprocessor.get('rgb_to_bgr', False))
                reid_img = imrenormalize(img, img_norm_cfg,
                                         self.reid['img_norm_cfg'])
            else:
                reid_img = img.clone()

        # filter out low-score detections
        valid_inds = scores > self.obj_score_thr
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]

        if self.empty or bboxes.size(0) == 0:
            # no existing tracks (or no detections): every remaining
            # detection starts a new track
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(
                self.num_tracks,
                self.num_tracks + num_new_tracks,
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += num_new_tracks
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                if crops.size(0) > 0:
                    embeds = model.reid(crops, mode='tensor')
                else:
                    embeds = crops.new_zeros((0, model.reid.head.out_channels))
        else:
            # initialize all detections as unmatched (-1)
            ids = torch.full((bboxes.size(0), ), -1,
                             dtype=torch.long).to(bboxes.device)

            # motion: optionally compensate camera motion, then run the
            # motion model to predict track states and compute motion
            # distances to the detections
            if model.with_cmc:
                num_samples = 1
                self.tracks = model.cmc.track(self.last_img, img, self.tracks,
                                              num_samples, frame_id, metainfo)

            self.tracks, motion_dists = self.motion.track(
                self.tracks, bbox_xyxy_to_cxcyah(bboxes))

            active_ids = self.confirmed_ids
            if self.with_reid:
                crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
                                       rescale)
                embeds = model.reid(crops, mode='tensor')

                # reid: match confirmed tracks to detections with a weighted
                # sum of appearance (cosine) and motion distances
                if len(active_ids) > 0:
                    track_embeds = self.get(
                        'embeds',
                        active_ids,
                        self.reid.get('num_samples', None),
                        behavior='mean')
                    reid_dists = cosine_distance(track_embeds, embeds)
                    # invalidate pairs that were gated out by the motion model
                    valid_inds = [list(self.ids).index(_) for _ in active_ids]
                    reid_dists[~np.isfinite(motion_dists[
                        valid_inds, :])] = np.nan

                    weight_motion = self.reid.get('motion_weight')
                    match_dists = (1 - weight_motion) * reid_dists + \
                        weight_motion * motion_dists[valid_inds]

                    # support multi-class association: forbid matches across
                    # different labels with a large additive cost
                    track_labels = torch.tensor([
                        self.tracks[id]['labels'][-1] for id in active_ids
                    ]).to(bboxes.device)
                    cate_match = labels[None, :] == track_labels[:, None]
                    cate_cost = ((1 - cate_match.int()) * 1e6).cpu().numpy()
                    match_dists = match_dists + cate_cost

                    row, col = linear_sum_assignment(match_dists)
                    for r, c in zip(row, col):
                        dist = match_dists[r, c]
                        if not np.isfinite(dist):
                            continue
                        if dist <= self.reid['match_score_thr']:
                            ids[c] = active_ids[r]
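
            # For reference (hypothetical numbers): `linear_sum_assignment`
            # returns row/column indices of a minimum-cost one-to-one
            # assignment, e.g. for the cost matrix
            #   [[0.1, 0.9],
            #    [0.8, 0.2]]
            # it pairs track 0 with detection 0 and track 1 with detection 1
            # (total cost 0.3). Pairs whose cost exceeds `match_score_thr`
            # are rejected in the loop above.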

            # IoU: match the remaining tracks seen in the previous frame to
            # the still-unmatched detections
            active_ids = [
                id for id in self.ids if id not in ids
                and self.tracks[id].frame_ids[-1] == frame_id - 1
            ]
            if len(active_ids) > 0:
                active_dets = torch.nonzero(ids == -1).squeeze(1)
                track_bboxes = self.get('bboxes', active_ids)
                ious = bbox_overlaps(track_bboxes, bboxes[active_dets])

                # support multi-class association: forbid matches across
                # different labels with a large additive cost
                track_labels = torch.tensor([
                    self.tracks[id]['labels'][-1] for id in active_ids
                ]).to(bboxes.device)
                cate_match = labels[None, active_dets] == track_labels[:, None]
                cate_cost = (1 - cate_match.int()) * 1e6

                dists = (1 - ious + cate_cost).cpu().numpy()

                row, col = linear_sum_assignment(dists)
                for r, c in zip(row, col):
                    dist = dists[r, c]
                    if dist < 1 - self.match_iou_thr:
                        ids[active_dets[c]] = active_ids[r]

            # any detection that is still unmatched starts a new track
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum(),
                dtype=torch.long).to(bboxes.device)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            embeds=embeds if self.with_reid else None,
            frame_ids=frame_id)
        self.last_img = img

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids

        return pred_track_instances
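

# A minimal usage sketch (hypothetical; in practice the surrounding MOT model
# drives the tracker, so this is not how the library is normally called).
# It only illustrates the per-frame contract of `track`; `mot_model`, `video`
# and the pre-filled data samples are assumptions:
#
#   tracker = mot_model.tracker  # a StrongSORTTracker instance
#   for frame_id, (img, data_sample) in enumerate(video):
#       # the detector results and `data_sample.metainfo['frame_id']` must
#       # already be filled in before tracking
#       pred_track_instances = tracker.track(
#           model=mot_model, img=img, data_sample=data_sample)
#       # pred_track_instances.instances_id holds the per-object track ids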