diff --git a/configs/_base_/datasets/voc_bs16.py b/configs/_base_/datasets/voc_bs16.py new file mode 100644 index 000000000..73fa0bcc8 --- /dev/null +++ b/configs/_base_/datasets/voc_bs16.py @@ -0,0 +1,41 @@ +# dataset settings +dataset_type = 'VOC' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', size=224), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', size=(256, -1)), + dict(type='CenterCrop', crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_prefix='data/VOCdevkit/VOC2007/', + ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_prefix='data/VOCdevkit/VOC2007/', + ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_prefix='data/VOCdevkit/VOC2007/', + ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt', + pipeline=test_pipeline)) +evaluation = dict( + interval=1, metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1']) diff --git a/configs/vgg/vgg16_voc.py b/configs/vgg/vgg16_voc.py new file mode 100644 index 000000000..e6dd70bb4 --- /dev/null +++ b/configs/vgg/vgg16_voc.py @@ -0,0 +1,25 @@ +_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py'] + +# use different head for multilabel task +model = dict( + type='ImageClassifier', + backbone=dict(type='VGG', depth=16, num_classes=20), + neck=None, + 
head=dict( + type='MultiLabelClsHead', + loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))) + +# load model pretrained on imagenet +load_from = 'https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_imagenet-91b6d117.pth' # noqa + +# optimizer +optimizer = dict( + type='SGD', + lr=0.001, + momentum=0.9, + weight_decay=0, + paramwise_cfg=dict(custom_keys={'.backbone.classifier': dict(lr_mult=10)})) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=20, gamma=0.1) +runner = dict(type='EpochBasedRunner', max_epochs=40) diff --git a/mmcls/core/evaluation/mean_ap.py b/mmcls/core/evaluation/mean_ap.py index 06d8e589f..44331caba 100644 --- a/mmcls/core/evaluation/mean_ap.py +++ b/mmcls/core/evaluation/mean_ap.py @@ -57,8 +57,8 @@ def mAP(pred, target): float: A single float as mAP value. """ if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): - pred = pred.numpy() - target = target.numpy() + pred = pred.detach().cpu().numpy() + target = target.detach().cpu().numpy() elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): raise TypeError('pred and target should both be torch.Tensor or' 'np.ndarray') diff --git a/mmcls/core/evaluation/multilabel_eval_metrics.py b/mmcls/core/evaluation/multilabel_eval_metrics.py index 6cb9341a4..c8663a16b 100644 --- a/mmcls/core/evaluation/multilabel_eval_metrics.py +++ b/mmcls/core/evaluation/multilabel_eval_metrics.py @@ -24,8 +24,8 @@ def average_performance(pred, target, thr=None, k=None): tuple: (CP, CR, CF1, OP, OR, OF1) """ if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): - pred = pred.numpy() - target = target.numpy() + pred = pred.detach().cpu().numpy() + target = target.detach().cpu().numpy() elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): raise TypeError('pred and target should both be torch.Tensor or' 'np.ndarray') diff --git a/mmcls/datasets/multi_label.py 
b/mmcls/datasets/multi_label.py index eff051a38..274a6ad96 100644 --- a/mmcls/datasets/multi_label.py +++ b/mmcls/datasets/multi_label.py @@ -48,13 +48,12 @@ class MultiLabelDataset(BaseDataset): invalid_metrics = set(metrics) - set(allowed_metrics) if len(invalid_metrics) != 0: - raise KeyError(f'metirc {invalid_metrics} is not supported.') + raise ValueError(f'metric {invalid_metrics} is not supported.') if 'mAP' in metrics: mAP_value = mAP(results, gt_labels) eval_results['mAP'] = mAP_value - metrics.remove('mAP') - if len(metrics) != 0: + if len(set(metrics) - {'mAP'}) != 0: performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1'] performance_values = average_performance(results, gt_labels, **eval_kwargs) diff --git a/mmcls/models/classifiers/image.py b/mmcls/models/classifiers/image.py index 436f2ebf4..d13d4a49c 100644 --- a/mmcls/models/classifiers/image.py +++ b/mmcls/models/classifiers/image.py @@ -46,8 +46,10 @@ class ImageClassifier(BaseClassifier): img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. - gt_label (Tensor): of shape (N, 1) encoding the ground-truth label - of input images. + gt_label (Tensor): It should be of shape (N, 1) encoding the + ground-truth label of input images for single label task. It + should be of shape (N, C) encoding the ground-truth label + of input images for multi-labels task. 
Returns: dict[str, Tensor]: a dictionary of loss components diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py index 9c6dc349c..361802f4d 100644 --- a/mmcls/models/heads/__init__.py +++ b/mmcls/models/heads/__init__.py @@ -1,4 +1,8 @@ from .cls_head import ClsHead from .linear_head import LinearClsHead +from .multi_label_head import MultiLabelClsHead +from .multi_label_linear_head import MultiLabelLinearClsHead -__all__ = ['ClsHead', 'LinearClsHead'] +__all__ = [ + 'ClsHead', 'LinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead' +] diff --git a/mmcls/models/heads/multi_label_head.py b/mmcls/models/heads/multi_label_head.py new file mode 100644 index 000000000..e3109638c --- /dev/null +++ b/mmcls/models/heads/multi_label_head.py @@ -0,0 +1,53 @@ +import torch +import torch.nn.functional as F + +from ..builder import HEADS, build_loss +from .base_head import BaseHead + + +@HEADS.register_module() +class MultiLabelClsHead(BaseHead): + """Classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. 
+ + """ + + def __init__(self, + loss=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=1.0)): + super(MultiLabelClsHead, self).__init__() + + assert isinstance(loss, dict) + + self.compute_loss = build_loss(loss) + + def loss(self, cls_score, gt_label): + gt_label = gt_label.type_as(cls_score) + num_samples = len(cls_score) + losses = dict() + + # map difficult examples to positive ones + _gt_label = torch.abs(gt_label) + # compute loss + loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples) + losses['loss'] = loss + return losses + + def forward_train(self, cls_score, gt_label): + gt_label = gt_label.type_as(cls_score) + losses = self.loss(cls_score, gt_label) + return losses + + def simple_test(self, cls_score): + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.sigmoid(cls_score) if cls_score is not None else None + if pred is None or torch.onnx.is_in_onnx_export(): + return pred + pred = list(pred.detach().cpu().numpy()) + return pred diff --git a/mmcls/models/heads/multi_label_linear_head.py b/mmcls/models/heads/multi_label_linear_head.py new file mode 100644 index 000000000..b2507d0e0 --- /dev/null +++ b/mmcls/models/heads/multi_label_linear_head.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import normal_init + +from ..builder import HEADS +from .multi_label_head import MultiLabelClsHead + + +@HEADS.register_module() +class MultiLabelLinearClsHead(MultiLabelClsHead): + """Linear classification head for multilabel task. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + loss (dict): Config of classification loss. 
+ + """ + + def __init__(self, + num_classes, + in_channels, + loss=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=1.0)): + super(MultiLabelLinearClsHead, self).__init__(loss=loss) + + if num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.in_channels = in_channels + self.num_classes = num_classes + self._init_layers() + + def _init_layers(self): + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def init_weights(self): + normal_init(self.fc, mean=0, std=0.01, bias=0) + + def forward_train(self, x, gt_label): + gt_label = gt_label.type_as(x) + cls_score = self.fc(x) + losses = self.loss(cls_score, gt_label) + return losses + + def simple_test(self, img): + """Test without augmentation.""" + cls_score = self.fc(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.sigmoid(cls_score) if cls_score is not None else None + if torch.onnx.is_in_onnx_export(): + return pred + pred = list(pred.detach().cpu().numpy()) + return pred diff --git a/mmcls/models/losses/cross_entropy_loss.py b/mmcls/models/losses/cross_entropy_loss.py index 541bfc99c..ab9b264d0 100644 --- a/mmcls/models/losses/cross_entropy_loss.py +++ b/mmcls/models/losses/cross_entropy_loss.py @@ -11,7 +11,7 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None): Args: pred (torch.Tensor): The prediction with shape (N, C), C is the number of classes. - label (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): The gt label of the prediction. weight (torch.Tensor, optional): Sample-wise loss weight. reduction (str): The method used to reduce the loss. avg_factor (int, optional): Average factor that is used to average @@ -41,7 +41,7 @@ def binary_cross_entropy(pred, Args: pred (torch.Tensor): The prediction with shape (N, *). - label (torch.Tensor): The learning label with shape (N, *). 
+ label (torch.Tensor): The gt label with shape (N, *). weight (torch.Tensor, optional): Element-wise weight of loss with shape (N, ). Defaults to None. reduction (str): The method used to reduce the loss. diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 8c1a7e2bd..4492f3ad1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -242,7 +242,7 @@ def test_dataset_evaluation(): [0.8, 0.1, 0.1, 0.2]]) # the metric must be valid - with pytest.raises(KeyError): + with pytest.raises(ValueError): metric = 'coverage' dataset.evaluate(fake_results, metric=metric) # only one metric diff --git a/tests/test_heads.py b/tests/test_heads.py new file mode 100644 index 000000000..f3a705263 --- /dev/null +++ b/tests/test_heads.py @@ -0,0 +1,22 @@ +import torch + +from mmcls.models.heads import MultiLabelClsHead, MultiLabelLinearClsHead + + +def test_multilabel_head(): + head = MultiLabelClsHead() + fake_cls_score = torch.rand(4, 3) + fake_gt_label = torch.randint(0, 2, (4, 3)) + + losses = head.loss(fake_cls_score, fake_gt_label) + assert losses['loss'].item() > 0 + + +def test_multilabel_linear_head(): + head = MultiLabelLinearClsHead(3, 5) + fake_feat = torch.rand(4, 5) + fake_gt_label = torch.randint(0, 2, (4, 3)) + + head.init_weights() + losses = head.forward_train(fake_feat, fake_gt_label) + assert losses['loss'].item() > 0