mmpretrain/mmcls/models/heads/multi_task_head.py
Colle bac181f393
[Feature] Support Multi-task. (#1229)
* unit test for multi_task_head

* [Feature] MultiTaskHead (#628, #481)

* [Fix] lint for multi_task_head

* [Feature] Add `MultiTaskDataset` to support multi-task training.

* Update MultiTaskClsHead

* Update docs

* [CI] Add test mim CI. (#879)

* [Fix] Remove duplicated wide-resnet metafile.

* [Feature] Support MPS device. (#894)

* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests

* [Fix] Fix Albu crash bug. (#918)

* Fix albu bug: using albu will cause the label to change from array(x) to array([x]) and crash the training

* Fix common

* Use copy in case of a potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf

* Bump version to 0.23.2. (#937)

* [Improve] Use `forward_dummy` to calculate FLOPS. (#953)

* Update README

* [Docs] Fix typo for wrong reference. (#1036)

* [Doc] Fix typo in tutorial 2 (#1043)

* [Docs] Fix a typo in ImageClassifier (#1050)

* add mask to loss

* add another pipeline

* adapt the pipeline if there is no mask

* switch mask and task

* first version of multi data sample

* fix problem with attribute by getattr

* rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label'

* training without evaluation

* first working version

* add others metrics

* delete evaluation from dataset

* fix linter

* fix linter

* multi metrics

* first version of test

* change evaluate metric

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* add tests

* add test for multidatasample

* create a generic test

* create a generic test

* create a generic test

* change multi data sample

* correct test

* test

* add new test

* add test for dataset

* correct test

* correct test

* correct test

* correct test

* fix: #5

* run yapf

* fix linter

* fix linter

* fix linter

* fix isort

* fix isort

* fix docformatter

* fix docformatter

* fix linter

* fix linter

* fix data sample

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update data sample

* update head

* update head

* update multi data sample

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* update head

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix linter

* fix: #2

* fix: linter

* update multi head

* fix linter

* fix linter

* update data sample

* update data sample

* fix: linter

* update test

* test pipeline

* update pipeline

* update test

* update dataset

* update dataset

* fix linter

* fix linter

* update formatting

* add test for multi-task-eval

* update formatting

* fix linter

* update test

* update

* add test

* update metrics

* update metrics

* add doc for functions

* fix linter

* training for multitask 1.x

* fix linter

* run flake8

* run linter

* update test

* add mask in evaluation

* update metric doc

* update metric doc

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update metric doc

* update metric doc

* Fix cannot import name MultiTaskDataSample

* fix test_datasets

* fix test_datasets

* fix linter

* add an example of multitask

* change name of configs dataset

* Refactor the multi-task support

* correct test and metric

* add test to multidatasample

* add test to multidatasample

* correct test

* correct metrics and clshead

* Update mmcls/models/heads/cls_head.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update cls_head.py documentation

* lint

* lint

* fix: lint

* fix linter

* add eval mask

* fix documentation

* fix: single_label.py back to 1.x

* Update mmcls/models/heads/multi_task_head.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Remove multi-task configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: marouaneamz <maroineamil99@gmail.com>
Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
2022-12-30 10:36:00 +08:00
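
A minimal config sketch (illustrative: the backbone, sub-head types, and task names below are assumptions, not taken from this commit) of how the new head can be wired into a classifier:

    model = dict(
        type='ImageClassifier',
        backbone=dict(type='ResNet', depth=18),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiTaskHead',
            task_heads=dict(
                task1=dict(type='LinearClsHead', num_classes=10),
                task2=dict(type='LinearClsHead', num_classes=5),
            ),
            # Common settings forwarded to every sub-head via ``**kwargs``.
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        ),
    )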


# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn
from mmengine.model import ModuleDict

from mmcls.registry import MODELS
from mmcls.structures import MultiTaskDataSample
from .base_head import BaseHead


def loss_convertor(loss_func, task_name):
    """Wrap a sub-head's loss so it only sees the samples annotated for
    the task ``task_name``."""

    def wrapped(inputs, data_samples, **kwargs):
        mask = torch.empty(len(data_samples), dtype=torch.bool)
        task_data_samples = []
        for i, data_sample in enumerate(data_samples):
            assert isinstance(data_sample, MultiTaskDataSample)
            sample_mask = task_name in data_sample
            mask[i] = sample_mask
            if sample_mask:
                task_data_samples.append(data_sample.get(task_name))

        if len(task_data_samples) == 0:
            # No sample in the batch is annotated for this task.
            return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}

        # Mask the inputs of the task.
        def mask_inputs(inputs, mask):
            if isinstance(inputs, Sequence):
                return type(inputs)(
                    [mask_inputs(input, mask) for input in inputs])
            elif isinstance(inputs, torch.Tensor):
                return inputs[mask]

        masked_inputs = mask_inputs(inputs, mask)
        loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
        loss_output['mask_size'] = mask.sum().to(torch.float)
        return loss_output

    return wrapped
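
# Illustrative example (not from the original file): with a batch of 4
# samples where only samples 0 and 2 are annotated for 'task1', ``mask`` is
# [True, False, True, False], the wrapped loss feeds only those 2 rows of
# ``inputs`` to the sub-head, and ``mask_size`` records that 2 samples
# contributed to the loss.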
@MODELS.register_module()
class MultiTaskHead(BaseHead):
    """Multi-task head.

    Args:
        task_heads (dict): Sub heads to use. The keys are task names, the
            values are configs of (or built) heads for those tasks, and each
            task name is used to prefix the corresponding loss components.
        init_cfg (dict, optional): The extra initialization settings.
            Defaults to None.
        **kwargs: The common settings shared by all sub-heads, used as
            default arguments when building each sub-head.
    """

    def __init__(self, task_heads, init_cfg=None, **kwargs):
        super(MultiTaskHead, self).__init__(init_cfg=init_cfg)

        assert isinstance(task_heads, dict), \
            'The `task_heads` argument should be a dict, whose keys are ' \
            'task names and values are configs of the head for each task.'

        self.task_heads = ModuleDict()
        for task_name, sub_head in task_heads.items():
            if not isinstance(sub_head, nn.Module):
                sub_head = MODELS.build(sub_head, default_args=kwargs)
            # Wrap every sub-head's loss so it skips unannotated samples.
            sub_head.loss = loss_convertor(sub_head.loss, task_name)
            self.task_heads[task_name] = sub_head

    def forward(self, feats):
        """The forward process: run every sub-head on the same features."""
        return {
            task_name: head(feats)
            for task_name, head in self.task_heads.items()
        }

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components; every key of
            a task's loss is prefixed by the task name, like ``task1_loss``.
        """
        losses = dict()
        for task_name, head in self.task_heads.items():
            head_loss = head.loss(feats, data_samples, **kwargs)
            for k, v in head_loss.items():
                losses[f'{task_name}_{k}'] = v
        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: List[MultiTaskDataSample] = None
    ) -> List[MultiTaskDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample], optional): The
                annotation data of every sample. If not None, set
                ``pred_label`` of the input data samples. Defaults to None.

        Returns:
            List[MultiTaskDataSample]: A list of data samples which contains
            the predicted results.
        """
        predictions_dict = dict()
        for task_name, head in self.task_heads.items():
            task_samples = head.predict(feats)
            batch_size = len(task_samples)
            predictions_dict[task_name] = task_samples

        if data_samples is None:
            data_samples = [MultiTaskDataSample() for _ in range(batch_size)]

        for task_name, task_samples in predictions_dict.items():
            for data_sample, task_sample in zip(data_samples, task_samples):
                # Record whether this sample is annotated for the task, so
                # metrics can skip unannotated samples during evaluation.
                task_sample.set_field(
                    task_name in data_sample.tasks,
                    'eval_mask',
                    field_type='metainfo')

                if task_name in data_sample.tasks:
                    data_sample.get(task_name).update(task_sample)
                else:
                    data_sample.set_field(task_sample, task_name)

        return data_samples
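
A short usage sketch (assumed: mmcls 1.x is installed; the sub-head settings and task names are illustrative) showing how the head masks per-task losses and attaches per-task predictions:

    import torch
    from mmcls.models.heads import MultiTaskHead
    from mmcls.structures import ClsDataSample, MultiTaskDataSample

    head = MultiTaskHead(
        task_heads=dict(
            task1=dict(type='LinearClsHead', num_classes=3),
            task2=dict(type='LinearClsHead', num_classes=2),
        ),
        # Common settings forwarded to both sub-heads.
        in_channels=8,
        loss=dict(type='CrossEntropyLoss'),
    )

    feats = (torch.rand(4, 8), )  # backbone features for a batch of 4

    # Each MultiTaskDataSample holds one ClsDataSample per annotated task;
    # here only half the batch is annotated for task2.
    data_samples = []
    for i in range(4):
        sample = MultiTaskDataSample()
        sample.set_field(ClsDataSample().set_gt_label(i % 3), 'task1')
        if i % 2 == 0:
            sample.set_field(ClsDataSample().set_gt_label(i % 2), 'task2')
        data_samples.append(sample)

    losses = head.loss(feats, data_samples)
    # -> keys like 'task1_loss', 'task1_mask_size', 'task2_loss', ...
    predictions = head.predict(feats)  # List[MultiTaskDataSample]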