# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot

from .multi_label import AveragePrecision, MultiLabelMetric


class VOCMetricMixin:
    """A mixin class for VOC dataset metrics. VOC annotations carry an extra
    `difficult` attribute for each object, so an extra option is needed when
    calculating VOC metrics.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If set
            to True, difficult gt labels are mapped to positive ones (1); if
            set to False, they are mapped to negative ones (0). Defaults to
            None, in which case difficult labels are set to -1.
    """

    def __init__(self,
                 *arg,
                 difficult_as_positive: Optional[bool] = None,
                 **kwarg):
        self.difficult_as_positive = difficult_as_positive
        super().__init__(*arg, **kwarg)
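
    # Note: this mixin relies on cooperative `super().__init__` and must
    # precede the concrete metric class in the MRO, so that `process` below
    # overrides the base implementation (see the subclasses at the end of
    # this file).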

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            gt_label = data_sample['gt_label']
            gt_label_difficult = data_sample['gt_label_difficult']

            result['pred_score'] = data_sample['pred_score'].clone()
            num_classes = result['pred_score'].size()[-1]

            if 'gt_score' in data_sample:
                result['gt_score'] = data_sample['gt_score'].clone()
            else:
                result['gt_score'] = label_to_onehot(gt_label, num_classes)

            # A VOC annotation labels all objects in an image, so a category
            # may appear on both difficult and non-difficult objects. Here we
            # treat only the labels that appear exclusively on difficult
            # objects as difficult labels.
            difficult_label = set(gt_label_difficult) - (
                set(gt_label_difficult) & set(gt_label.tolist()))
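            # E.g. (made-up values): with gt_label = [0, 2] and
            # gt_label_difficult = [2, 5], difficult_label is {5}, because
            # class 2 also appears on a non-difficult object.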

            # Remap the difficult labels according to `difficult_as_positive`.
            if self.difficult_as_positive is None:
                result['gt_score'][[*difficult_label]] = -1
            elif self.difficult_as_positive:
                result['gt_score'][[*difficult_label]] = 1

            # Save the result to `self.results`.
            self.results.append(result)


@METRICS.register_module()
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
    """A collection of confusion-matrix-based metrics for the multi-label,
    multi-class classification task on the VOC dataset.

    It includes precision, recall, f1-score and support.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If set
            to True, difficult gt labels are mapped to positive ones (1); if
            set to False, they are mapped to negative ones (0). Defaults to
            None, in which case difficult labels are set to -1.
        **kwarg: Refer to `MultiLabelMetric` for detailed docstrings.
    """


@METRICS.register_module()
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
    """Calculate the average precision with respect to classes for the VOC
    dataset.

    Args:
        difficult_as_positive (Optional[bool]): Whether to map the difficult
            labels as positive in one-hot ground truth for evaluation. If set
            to True, difficult gt labels are mapped to positive ones (1); if
            set to False, they are mapped to negative ones (0). Defaults to
            None, in which case difficult labels are set to -1.
        **kwarg: Refer to `AveragePrecision` for detailed docstrings.
    """