Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| # Partly adopted from https://github.com/GT-Vision-Lab/VQA | |
| # Copyright (c) 2014, Aishwarya Agrawal | |
| from typing import List, Optional | |
| import mmengine | |
| from mmengine.evaluator import BaseMetric | |
| from mmengine.logging import MMLogger | |
| from mmpretrain.registry import METRICS | |
| def _process_punctuation(inText): | |
| import re | |
| outText = inText | |
| punct = [ | |
| ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', | |
| '>', '<', '@', '`', ',', '?', '!' | |
| ] | |
| commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 | |
| periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 | |
| for p in punct: | |
| if (p + ' ' in inText or ' ' + p in inText) or (re.search( | |
| commaStrip, inText) is not None): | |
| outText = outText.replace(p, '') | |
| else: | |
| outText = outText.replace(p, ' ') | |
| outText = periodStrip.sub('', outText, re.UNICODE) | |
| return outText | |
| def _process_digit_article(inText): | |
| outText = [] | |
| tempText = inText.lower().split() | |
| articles = ['a', 'an', 'the'] | |
| manualMap = { | |
| 'none': '0', | |
| 'zero': '0', | |
| 'one': '1', | |
| 'two': '2', | |
| 'three': '3', | |
| 'four': '4', | |
| 'five': '5', | |
| 'six': '6', | |
| 'seven': '7', | |
| 'eight': '8', | |
| 'nine': '9', | |
| 'ten': '10', | |
| } | |
| contractions = { | |
| 'aint': "ain't", | |
| 'arent': "aren't", | |
| 'cant': "can't", | |
| 'couldve': "could've", | |
| 'couldnt': "couldn't", | |
| "couldn'tve": "couldn't've", | |
| "couldnt've": "couldn't've", | |
| 'didnt': "didn't", | |
| 'doesnt': "doesn't", | |
| 'dont': "don't", | |
| 'hadnt': "hadn't", | |
| "hadnt've": "hadn't've", | |
| "hadn'tve": "hadn't've", | |
| 'hasnt': "hasn't", | |
| 'havent': "haven't", | |
| 'hed': "he'd", | |
| "hed've": "he'd've", | |
| "he'dve": "he'd've", | |
| 'hes': "he's", | |
| 'howd': "how'd", | |
| 'howll': "how'll", | |
| 'hows': "how's", | |
| "Id've": "I'd've", | |
| "I'dve": "I'd've", | |
| 'Im': "I'm", | |
| 'Ive': "I've", | |
| 'isnt': "isn't", | |
| 'itd': "it'd", | |
| "itd've": "it'd've", | |
| "it'dve": "it'd've", | |
| 'itll': "it'll", | |
| "let's": "let's", | |
| 'maam': "ma'am", | |
| 'mightnt': "mightn't", | |
| "mightnt've": "mightn't've", | |
| "mightn'tve": "mightn't've", | |
| 'mightve': "might've", | |
| 'mustnt': "mustn't", | |
| 'mustve': "must've", | |
| 'neednt': "needn't", | |
| 'notve': "not've", | |
| 'oclock': "o'clock", | |
| 'oughtnt': "oughtn't", | |
| "ow's'at": "'ow's'at", | |
| "'ows'at": "'ow's'at", | |
| "'ow'sat": "'ow's'at", | |
| 'shant': "shan't", | |
| "shed've": "she'd've", | |
| "she'dve": "she'd've", | |
| "she's": "she's", | |
| 'shouldve': "should've", | |
| 'shouldnt': "shouldn't", | |
| "shouldnt've": "shouldn't've", | |
| "shouldn'tve": "shouldn't've", | |
| "somebody'd": 'somebodyd', | |
| "somebodyd've": "somebody'd've", | |
| "somebody'dve": "somebody'd've", | |
| 'somebodyll': "somebody'll", | |
| 'somebodys': "somebody's", | |
| 'someoned': "someone'd", | |
| "someoned've": "someone'd've", | |
| "someone'dve": "someone'd've", | |
| 'someonell': "someone'll", | |
| 'someones': "someone's", | |
| 'somethingd': "something'd", | |
| "somethingd've": "something'd've", | |
| "something'dve": "something'd've", | |
| 'somethingll': "something'll", | |
| 'thats': "that's", | |
| 'thered': "there'd", | |
| "thered've": "there'd've", | |
| "there'dve": "there'd've", | |
| 'therere': "there're", | |
| 'theres': "there's", | |
| 'theyd': "they'd", | |
| "theyd've": "they'd've", | |
| "they'dve": "they'd've", | |
| 'theyll': "they'll", | |
| 'theyre': "they're", | |
| 'theyve': "they've", | |
| 'twas': "'twas", | |
| 'wasnt': "wasn't", | |
| "wed've": "we'd've", | |
| "we'dve": "we'd've", | |
| 'weve': "we've", | |
| 'werent': "weren't", | |
| 'whatll': "what'll", | |
| 'whatre': "what're", | |
| 'whats': "what's", | |
| 'whatve': "what've", | |
| 'whens': "when's", | |
| 'whered': "where'd", | |
| 'wheres': "where's", | |
| 'whereve': "where've", | |
| 'whod': "who'd", | |
| "whod've": "who'd've", | |
| "who'dve": "who'd've", | |
| 'wholl': "who'll", | |
| 'whos': "who's", | |
| 'whove': "who've", | |
| 'whyll': "why'll", | |
| 'whyre': "why're", | |
| 'whys': "why's", | |
| 'wont': "won't", | |
| 'wouldve': "would've", | |
| 'wouldnt': "wouldn't", | |
| "wouldnt've": "wouldn't've", | |
| "wouldn'tve": "wouldn't've", | |
| 'yall': "y'all", | |
| "yall'll": "y'all'll", | |
| "y'allll": "y'all'll", | |
| "yall'd've": "y'all'd've", | |
| "y'alld've": "y'all'd've", | |
| "y'all'dve": "y'all'd've", | |
| 'youd': "you'd", | |
| "youd've": "you'd've", | |
| "you'dve": "you'd've", | |
| 'youll': "you'll", | |
| 'youre': "you're", | |
| 'youve': "you've", | |
| } | |
| for word in tempText: | |
| word = manualMap.setdefault(word, word) | |
| if word not in articles: | |
| outText.append(word) | |
| for wordId, word in enumerate(outText): | |
| if word in contractions: | |
| outText[wordId] = contractions[word] | |
| outText = ' '.join(outText) | |
| return outText | |
class VQAAcc(BaseMetric):
    """VQA accuracy metric.

    The prediction and every ground-truth answer are first normalized with
    the VQA answer-processing rules: whitespace, punctuation, number words,
    articles and contractions. The weights of all ground-truth answers equal
    to the prediction are then summed. The per-sample score is
    ``min(1, matched_weight / full_score_weight)``, and the reported ``acc``
    is the mean per-sample score as a percentage.

    Args:
        full_score_weight (float): Matched answer weight needed for a full
            per-sample score of 1. For example, with 10 equally weighted
            ground-truth answers and the default of 0.3, three matching
            answers give a full score. Must be positive. Defaults to 0.3.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Raises:
        ValueError: If ``full_score_weight`` is not positive.
    """
    default_prefix = 'VQA'

    def __init__(self,
                 full_score_weight: float = 0.3,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        # Fail fast: a non-positive weight would otherwise only surface at
        # compute time, as a ZeroDivisionError or a negative accuracy.
        if full_score_weight <= 0:
            raise ValueError('`full_score_weight` must be positive, but got '
                             f'{full_score_weight}.')
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.full_score_weight = full_score_weight

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results are stored in ``self.results``, which is used
        to compute the metrics once all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample is read for ``'pred_answer'``, ``'gt_answer'``
                (a string or a list of strings) and the optional
                ``'gt_answer_weight'`` (one weight per ground-truth answer).

        Raises:
            ValueError: If a sample has no ``'gt_answer'``.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            if gt_answer is None:
                raise ValueError(
                    'VQAAcc requires a `gt_answer` in every data sample.')
            gt_answer_weight = sample.get('gt_answer_weight')
            if isinstance(gt_answer, str):
                gt_answer = [gt_answer]
            if gt_answer_weight is None:
                # Uniform weights. An empty answer list gets no weights
                # instead of dividing by zero; such a sample scores 0.
                num_answers = len(gt_answer)
                gt_answer_weight = ([1. / num_answers] *
                                    num_answers if num_answers else [])
            self.results.append({
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer,
                'gt_answer_weight': gt_answer_weight,
            })

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list[dict]): The records produced by :meth:`process`.

        Returns:
            Dict: ``{'acc': accuracy}``, where accuracy is the mean
            per-sample VQA score in percent. It is 0.0 (with a logged
            warning) when ``results`` is empty.
        """
        if not results:
            # Avoid ZeroDivisionError on an empty evaluation set.
            MMLogger.get_current_instance().warning(
                'VQAAcc received no results; reporting an accuracy of 0.')
            return {'acc': 0.0}

        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = [
                self._process_answer(answer) for answer in result['gt_answer']
            ]
            answer_weight = result['gt_answer_weight']
            # Sum the weights of every ground-truth answer the prediction
            # matches; duplicates among annotators each contribute.
            weight_sum = sum(answer_weight[i]
                             for i, gt in enumerate(gt_answer)
                             if gt == pred_answer)
            acc.append(min(1.0, weight_sum / self.full_score_weight))

        accuracy = sum(acc) / len(acc) * 100
        return {'acc': accuracy}

    def _process_answer(self, answer):
        """Normalize an answer string with the VQA eval processing rules."""
        answer = answer.replace('\n', ' ')
        answer = answer.replace('\t', ' ')
        answer = answer.strip()
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
class ReportVQA(BaseMetric):
    """Dump VQA predictions to a JSON file in the standard VQA result format.

    No metric is computed. Each sample is reduced to a
    ``{'question_id': int, 'answer': <predicted answer>}`` record, and all
    collected records are written to ``file_path`` when evaluation finishes.

    Args:
        file_path (str): The file path to save the result file. Must end
            with ``.json``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Raises:
        ValueError: If ``file_path`` does not end with ``.json``.
    """
    default_prefix = 'VQA'

    def __init__(self,
                 file_path: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        if not file_path.endswith('.json'):
            raise ValueError('The output file must be a json file.')
        self.file_path = file_path

    def process(self, data_batch, data_samples) -> None:
        """Collect a ``question_id``/``answer`` record for each sample.

        Args:
            data_batch: A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample must provide ``'question_id'`` (convertible to
                ``int``) and ``'pred_answer'``.
        """
        for sample in data_samples:
            question_id = sample['question_id']
            pred_answer = sample['pred_answer']
            result = {
                'question_id': int(question_id),
                'answer': pred_answer,
            }
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Write the collected records to ``self.file_path``.

        Args:
            results (list[dict]): The records produced by :meth:`process`.

        Returns:
            dict: Always empty, since this evaluator only exports predictions.
        """
        mmengine.dump(results, self.file_path)
        logger = MMLogger.get_current_instance()
        logger.info(f'Results has been saved to {self.file_path}.')
        return {}