121 lines
4.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS

@METRICS.register_module()
class MultiTasksMetric(BaseMetric):
    """Metrics for multi-task learning.

    Args:
        task_metrics (dict): A dictionary whose keys are task names and whose
            values are lists of metric configs evaluated for the
            corresponding task.
        collect_device (str): Device used to collect results from different
            ranks during distributed evaluation. Defaults to 'cpu'.

    Examples:
        >>> import torch
        >>> from mmcls.evaluation import MultiTasksMetric
        # -------------------- The Basic Usage --------------------
        >>> task_metrics = {
            'task0': [dict(type='Accuracy', topk=(1, ))],
            'task1': [dict(type='Accuracy', topk=(1, 3))]
        }
        >>> pred = [{
            'pred_task': {
                'task0': torch.tensor([0.7, 0.0, 0.3]),
                'task1': torch.tensor([0.5, 0.2, 0.3])
            },
            'gt_task': {
                'task0': torch.tensor(0),
                'task1': torch.tensor(2)
            }
        }, {
            'pred_task': {
                'task0': torch.tensor([0.0, 0.0, 1.0]),
                'task1': torch.tensor([0.0, 0.0, 1.0])
            },
            'gt_task': {
                'task0': torch.tensor(2),
                'task1': torch.tensor(2)
            }
        }]
        >>> metric = MultiTasksMetric(task_metrics)
        >>> metric.process(None, pred)
        >>> results = metric.evaluate(2)
        results = {
            'task0_accuracy/top1': 100.0,
            'task1_accuracy/top1': 50.0,
            'task1_accuracy/top3': 100.0
        }
    """

    def __init__(self,
                 task_metrics: Dict,
                 collect_device: str = 'cpu') -> None:
        self.task_metrics = task_metrics
        super().__init__(collect_device=collect_device)

        # Build one metric instance per config entry, grouped by task name,
        # e.g. {'task0': [Accuracy(...)], 'task1': [Accuracy(...), ...]}.
        self._metrics = {}
        for task_name in self.task_metrics.keys():
            self._metrics[task_name] = []
            for metric in self.task_metrics[task_name]:
                self._metrics[task_name].append(METRICS.build(metric))

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for task_name in self.task_metrics.keys():
            # Keep only the samples whose `eval_mask` marks them as labeled
            # for this task, then forward the per-task sub-samples to every
            # metric registered for the task.
            filtered_data_samples = []
            for data_sample in data_samples:
                eval_mask = data_sample[task_name]['eval_mask']
                if eval_mask:
                    filtered_data_samples.append(data_sample[task_name])
            for metric in self._metrics[task_name]:
                metric.process(data_batch, filtered_data_samples)

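    # Illustrative note (not from the original file): `process` expects each
    # data sample to be a dict keyed by task name, where every per-task entry
    # carries an 'eval_mask' flag plus whatever fields its sub-metrics read,
    # roughly:
    #   {'task0': {'eval_mask': True,  ...prediction/label fields...},
    #    'task1': {'eval_mask': False, ...}}   # 'task1' skipped for this sample
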
    def compute_metrics(self, results: list) -> dict:
        # Each sub-metric computes its own results inside `evaluate`, so this
        # hook required by `BaseMetric` is intentionally left unimplemented.
        raise NotImplementedError(
            'compute metrics should not be used here directly')

    def evaluate(self, size):
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
                "{task_name}_{metric_name}" and the values are the
                corresponding results.
        """
        metrics = {}
        for task_name in self._metrics:
            for metric in self._metrics[task_name]:
                name = metric.__class__.__name__
                if name == 'MultiTasksMetric' or metric.results:
                    results = metric.evaluate(size)
                else:
                    # No sample of this task was processed; report 0 under
                    # the metric's class name instead of failing.
                    results = {metric.__class__.__name__: 0}
                for key in results:
                    name = f'{task_name}_{key}'
                    if name in metrics:
                        # Inspired by https://github.com/open-mmlab/mmengine/blob/ed20a9cba52ceb371f7c825131636b9e2747172e/mmengine/evaluator/evaluator.py#L84-L87
                        raise ValueError(
                            'There are multiple metric results with the same '
                            f'metric name {name}. Please make sure all '
                            'metrics have different prefixes.')
                    metrics[name] = results[key]
        return metrics
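Below is a minimal usage sketch, not part of the original file: it only exercises the constructor, reusing the task names and `Accuracy` configs from the class docstring, and the comment about `_metrics` describes what `__init__` builds. It assumes an environment with mmcls 1.x installed.

from mmcls.evaluation import MultiTasksMetric

# Task names and Accuracy configs mirror the class docstring above.
task_metrics = {
    'task0': [dict(type='Accuracy', topk=(1, ))],
    'task1': [dict(type='Accuracy', topk=(1, 3))],
}
metric = MultiTasksMetric(task_metrics)
# `__init__` builds one metric object per config entry, grouped by task:
#   metric._metrics == {'task0': [<Accuracy>], 'task1': [<Accuracy>]}

Because the class is registered in `METRICS`, it can also be referenced from a training config in the usual MMEngine way. The fragment below is a hypothetical sketch; the `val_evaluator`/`test_evaluator` keys are standard MMEngine config fields, not taken from this file.

# Hypothetical config fragment (illustrative only):
val_evaluator = dict(
    type='MultiTasksMetric',
    task_metrics={
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [dict(type='Accuracy', topk=(1, 3))],
    })
test_evaluator = val_evaluator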