[Feature] Add heads and config for multilabel task (#145)
* resolve conflicts add heads and config for multilabel tasks * minor change * remove evaluating mAP in head * add baseline config * add configs * reserve only one config * minor change * fix minor bug * minor change * minor change * add unittests and fix docstringspull/151/head
parent
13c1210741
commit
07bb15e5fd
|
@ -0,0 +1,41 @@
|
|||
# dataset settings
|
||||
dataset_type = 'VOC'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=(256, -1)),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=16,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/VOCdevkit/VOC2007/',
|
||||
ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/VOCdevkit/VOC2007/',
|
||||
ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/VOCdevkit/VOC2007/',
|
||||
ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt',
|
||||
pipeline=test_pipeline))
|
||||
evaluation = dict(
|
||||
interval=1, metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1'])
|
|
@ -0,0 +1,25 @@
|
|||
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']
|
||||
|
||||
# use different head for multilabel task
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='VGG', depth=16, num_classes=20),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='MultiLabelClsHead',
|
||||
loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
|
||||
|
||||
# load model pretrained on imagenet
|
||||
load_from = 'https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_imagenet-91b6d117.pth' # noqa
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD',
|
||||
lr=0.001,
|
||||
momentum=0.9,
|
||||
weight_decay=0,
|
||||
paramwise_cfg=dict(custom_keys={'.backbone.classifier': dict(lr_mult=10)}))
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=20, gamma=0.1)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=40)
|
|
@ -57,8 +57,8 @@ def mAP(pred, target):
|
|||
float: A single float as mAP value.
|
||||
"""
|
||||
if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
|
||||
pred = pred.numpy()
|
||||
target = target.numpy()
|
||||
pred = pred.detach().cpu().numpy()
|
||||
target = target.detach().cpu().numpy()
|
||||
elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
|
||||
raise TypeError('pred and target should both be torch.Tensor or'
|
||||
'np.ndarray')
|
||||
|
|
|
@ -24,8 +24,8 @@ def average_performance(pred, target, thr=None, k=None):
|
|||
tuple: (CP, CR, CF1, OP, OR, OF1)
|
||||
"""
|
||||
if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
|
||||
pred = pred.numpy()
|
||||
target = target.numpy()
|
||||
pred = pred.detach().cpu().numpy()
|
||||
target = target.detach().cpu().numpy()
|
||||
elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
|
||||
raise TypeError('pred and target should both be torch.Tensor or'
|
||||
'np.ndarray')
|
||||
|
|
|
@ -48,13 +48,12 @@ class MultiLabelDataset(BaseDataset):
|
|||
|
||||
invalid_metrics = set(metrics) - set(allowed_metrics)
|
||||
if len(invalid_metrics) != 0:
|
||||
raise KeyError(f'metirc {invalid_metrics} is not supported.')
|
||||
raise ValueError(f'metirc {invalid_metrics} is not supported.')
|
||||
|
||||
if 'mAP' in metrics:
|
||||
mAP_value = mAP(results, gt_labels)
|
||||
eval_results['mAP'] = mAP_value
|
||||
metrics.remove('mAP')
|
||||
if len(metrics) != 0:
|
||||
if len(set(metrics) - {'mAP'}) != 0:
|
||||
performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
|
||||
performance_values = average_performance(results, gt_labels,
|
||||
**eval_kwargs)
|
||||
|
|
|
@ -46,8 +46,10 @@ class ImageClassifier(BaseClassifier):
|
|||
img (Tensor): of shape (N, C, H, W) encoding input images.
|
||||
Typically these should be mean centered and std scaled.
|
||||
|
||||
gt_label (Tensor): of shape (N, 1) encoding the ground-truth label
|
||||
of input images.
|
||||
gt_label (Tensor): It should be of shape (N, 1) encoding the
|
||||
ground-truth label of input images for single label task. It
|
||||
shoulf be of shape (N, C) encoding the ground-truth label
|
||||
of input images for multi-labels task.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from .cls_head import ClsHead
|
||||
from .linear_head import LinearClsHead
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
from .multi_label_linear_head import MultiLabelLinearClsHead
|
||||
|
||||
__all__ = ['ClsHead', 'LinearClsHead']
|
||||
__all__ = [
|
||||
'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS, build_loss
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class MultiLabelClsHead(BaseHead):
|
||||
"""Classification head for multilabel task.
|
||||
|
||||
Args:
|
||||
loss (dict): Config of classification loss.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0)):
|
||||
super(MultiLabelClsHead, self).__init__()
|
||||
|
||||
assert isinstance(loss, dict)
|
||||
|
||||
self.compute_loss = build_loss(loss)
|
||||
|
||||
def loss(self, cls_score, gt_label):
|
||||
gt_label = gt_label.type_as(cls_score)
|
||||
num_samples = len(cls_score)
|
||||
losses = dict()
|
||||
|
||||
# map difficult examples to positive ones
|
||||
_gt_label = torch.abs(gt_label)
|
||||
# compute loss
|
||||
loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples)
|
||||
losses['loss'] = loss
|
||||
return losses
|
||||
|
||||
def forward_train(self, cls_score, gt_label):
|
||||
gt_label = gt_label.type_as(cls_score)
|
||||
losses = self.loss(cls_score, gt_label)
|
||||
return losses
|
||||
|
||||
def simple_test(self, cls_score):
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
|
@ -0,0 +1,60 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from ..builder import HEADS
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class MultiLabelLinearClsHead(MultiLabelClsHead):
|
||||
"""Linear classification head for multilabel task.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
loss (dict): Config of classification loss.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
in_channels,
|
||||
loss=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0)):
|
||||
super(MultiLabelLinearClsHead, self).__init__(loss=loss)
|
||||
|
||||
if num_classes <= 0:
|
||||
raise ValueError(
|
||||
f'num_classes={num_classes} must be a positive integer')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.fc, mean=0, std=0.01, bias=0)
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
gt_label = gt_label.type_as(x)
|
||||
cls_score = self.fc(x)
|
||||
losses = self.loss(cls_score, gt_label)
|
||||
return losses
|
||||
|
||||
def simple_test(self, img):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.fc(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
|
@ -11,7 +11,7 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
|
|||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
||||
of classes.
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
label (torch.Tensor): The gt label of the prediction.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
reduction (str): The method used to reduce the loss.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
|
@ -41,7 +41,7 @@ def binary_cross_entropy(pred,
|
|||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, *).
|
||||
label (torch.Tensor): The learning label with shape (N, *).
|
||||
label (torch.Tensor): The gt label with shape (N, *).
|
||||
weight (torch.Tensor, optional): Element-wise weight of loss with shape
|
||||
(N, ). Defaults to None.
|
||||
reduction (str): The method used to reduce the loss.
|
||||
|
|
|
@ -242,7 +242,7 @@ def test_dataset_evaluation():
|
|||
[0.8, 0.1, 0.1, 0.2]])
|
||||
|
||||
# the metric must be valid
|
||||
with pytest.raises(KeyError):
|
||||
with pytest.raises(ValueError):
|
||||
metric = 'coverage'
|
||||
dataset.evaluate(fake_results, metric=metric)
|
||||
# only one metric
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
|
||||
from mmcls.models.heads import MultiLabelClsHead, MultiLabelLinearClsHead
|
||||
|
||||
|
||||
def test_multilabel_head():
|
||||
head = MultiLabelClsHead()
|
||||
fake_cls_score = torch.rand(4, 3)
|
||||
fake_gt_label = torch.randint(0, 2, (4, 3))
|
||||
|
||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_multilabel_linear_head():
|
||||
head = MultiLabelLinearClsHead(3, 5)
|
||||
fake_cls_score = torch.rand(4, 3)
|
||||
fake_gt_label = torch.randint(0, 2, (4, 3))
|
||||
|
||||
head.init_weights()
|
||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
Loading…
Reference in New Issue