[Fix]: Fix multi-head config

pull/352/head
YuanLiuuuuuu 2022-07-06 08:18:46 +00:00 committed by fangyixiao18
parent 1e16016b27
commit 8910743c6e
5 changed files with 115 additions and 64 deletions

View File

@@ -1,5 +1,10 @@
model = dict(
type='Classification',
type='ImageClassifier',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
),
backbone=dict(
type='ResNet',
depth=50,
@@ -8,7 +13,7 @@ model = dict(
norm_cfg=dict(type='BN'),
frozen_stages=-1),
head=dict(
type='MultiClsHead',
type='mmselfsup.MultiClsHead',
pool_type='specified',
in_indices=[0, 1, 2, 3, 4],
with_last_layer_unpool=False,
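Taken together, the new-style base model moves the normalization that used to live in the data pipeline into the model itself. A sketch of the assembled config, using only the fields visible in the hunks above (the head's trailing fields are cut off by the diff):

model = dict(
    type='ImageClassifier',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True,
    ),
    backbone=dict(
        type='ResNet',
        depth=50,
        norm_cfg=dict(type='BN'),
        frozen_stages=-1),
    head=dict(
        # the 'mmselfsup.' prefix scopes the registry lookup, since the
        # classifier itself is built from the mmcls registry
        type='mmselfsup.MultiClsHead',
        pool_type='specified',
        in_indices=[0, 1, 2, 3, 4],
        with_last_layer_unpool=False))  # remaining head fields truncated in the diff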

View File

@@ -4,53 +4,77 @@ _base_ = [
'../_base_/schedules/sgd_steplr-100e.py',
'../_base_/default_runtime.py',
]
# Multi-head linear evaluation setting
model = dict(backbone=dict(frozen_stages=4))
# lighting params, in order of BGR
EIGVAL = [55.4625, 4.7940, 1.1475]
EIGVEC = [
[-0.5836, -0.6948, 0.4203],
[-0.5808, -0.0045, -0.8140],
[-0.5675, 0.7192, 0.4009],
]
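These are the ImageNet eigenvalues/eigenvectors used for AlexNet-style PCA color jitter. A minimal NumPy sketch of what a Lighting transform computes with them (the function name and alphastd default are assumptions, not the mmcls API):

import numpy as np

def pca_lighting(img, eigval=EIGVAL, eigvec=EIGVEC, alphastd=0.1):
    # img: float array of shape (H, W, 3), channel order matching eigvec (BGR here)
    alpha = np.random.normal(0., alphastd, size=3)         # per-image jitter
    shift = np.asarray(eigvec) @ (alpha * np.asarray(eigval))
    return img + shift                                     # broadcast over H, W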
# dataset settings
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# dataset
train_pipeline = [
dict(type='RandomResizedCrop', size=224),
dict(type='RandomHorizontalFlip'),
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='ToTensor'),
dict(type='Lighting'),
dict(type='Normalize', **img_norm_cfg),
dict(
type='Lighting',
eigval=EIGVAL,
eigvec=EIGVEC,
),
dict(type='PackClsInputs')
]
test_pipeline = [
dict(type='Resize', size=256),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
]
data = dict(
train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline))
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
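The old monolithic data = dict(...) override becomes per-dataloader overrides; mmengine merges these dicts recursively into the _base_ config, so only the pipeline key is swapped while batch size, dataset type, etc. are inherited. A toy sketch of that merge semantics (illustrative, not the real mmengine API):

def merge(base: dict, override: dict) -> dict:
    out = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)   # recurse into nested dicts
        else:
            out[k] = v                  # leaf values are replaced
    return out

base = dict(batch_size=32, dataset=dict(type='ImageNet', pipeline=[]))
merged = merge(base, dict(dataset=dict(pipeline=train_pipeline)))
# batch_size and dataset.type survive; only dataset.pipeline changes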
# Multi-head linear evaluation setting
model = dict(
backbone=dict(out_indices=[0, 1, 2, 3], frozen_stages=4),
head=dict(in_indices=[1, 2, 3, 4]))
# optimizer
optimizer = dict(
type='SGD',
lr=0.01,
momentum=0.9,
weight_decay=1e-4,
paramwise_options=dict(norm_decay_mult=0.),
nesterov=True)
type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4, nesterov=True)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
paramwise_cfg=dict(norm_decay_mult=0.0))
# learning rate scheduler
param_scheduler = [
dict(
type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
]
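With lr=0.01 and milestones [30, 60, 90], MultiStepLR decays the rate tenfold at each milestone (0.01 -> 1e-3 -> 1e-4). The weight-decay exemption moves from the optimizer's paramwise_options to the wrapper's paramwise_cfg; a plain PyTorch sketch of what norm_decay_mult=0.0 amounts to (mmengine's OptimWrapper derives these parameter groups internally):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
norm_params, other_params = [], []
for m in model.modules():
    is_norm = isinstance(m, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm))
    for p in m.parameters(recurse=False):
        (norm_params if is_norm else other_params).append(p)

optimizer = torch.optim.SGD(
    [dict(params=other_params, weight_decay=1e-4),  # decayed as usual
     dict(params=norm_params, weight_decay=0.0)],   # norm_decay_mult=0.0
    lr=0.01, momentum=0.9, nesterov=True)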
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=90)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
# the max_keep_ckpts controls the max number of ckpt files in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
# evaluator
val_evaluator = dict(
_delete_=True,
type='MultiHeadEvaluator',
metrics=dict(
head1=dict(type='mmcls.Accuracy', topk=(1, 5)),
head2=dict(type='mmcls.Accuracy', topk=(1, 5)),
head3=dict(type='mmcls.Accuracy', topk=(1, 5)),
head4=dict(type='mmcls.Accuracy', topk=(1, 5))))
# epochs
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=90)
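Each entry of metrics feeds one linear head. A rough sketch of the dispatch idea behind a multi-head evaluator (assumed names and shapes, not the actual MultiHeadEvaluator implementation):

class MultiHeadEvaluatorSketch:
    def __init__(self, metrics: dict):
        self.metrics = metrics  # e.g. {'head1': Accuracy(topk=(1, 5)), ...}

    def evaluate(self, per_head_scores: dict, targets):
        # run each head's score tensor through its own metric
        return {name: metric(per_head_scores[name], targets)
                for name, metric in self.metrics.items()}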

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, List, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -22,20 +22,29 @@ class MultiClsHead(ClsHead):
linear classifier at each stage to predict corresponding class scores.
Args:
backbone (str, optional): Specify which backbone to use, only support
backbone (str): Specify which backbone to use, only support
ResNet50. Defaults to 'resnet50'.
in_indices (Sequence[int], optional): Input from which stages.
in_indices (Sequence[int]): Input from which stages.
Defaults to (0, 1, 2, 3, 4).
pool_type (str, optional): 'adaptive' or 'specified'. If set to
pool_type (str): 'adaptive' or 'specified'. If set to
'adaptive', use adaptive average pooling, otherwise use specified
pooling params. Defaults to 'adaptive'.
num_classes (int, optional): Number of classes. Defaults to 1000.
loss (Dict, optional): The Dict of loss information. Defaults to
'mmcls.models.CrossEntropyLoss'
with_last_layer_unpool (bool, optional): Whether to unpool the features
num_classes (int): Number of classes. Defaults to 1000.
loss (dict): The Dict of loss information. Defaults to
'mmcls.models.CrossEntropyLoss'.
with_last_layer_unpool (bool): Whether to unpool the features
from last layer. Defaults to False.
norm_cfg (Dict, optional): Dict to construct and config norm layer.
init_cfg (Dict or List[Dict], optional): Initialization config dict.
cal_acc (bool): Whether to calculate accuracy during training.
If you use batch augmentations like Mixup and CutMix during
training, it is pointless to calculate accuracy.
Defaults to False.
topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
norm_cfg (dict): Dict to construct and config norm layer.
Defaults to ``dict(type='BN')``.
init_cfg (dict or List[dict]): Initialization config dict.
Defaults to ``[
dict(type='Normal', std=0.01, layer='Linear'),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]``
"""
FEAT_CHANNELS = {'resnet50': [64, 256, 512, 1024, 2048]}
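Under the updated signature, the benchmark config above corresponds to an instantiation roughly like this (a sketch; values taken from the configs in this commit, other arguments left at their defaults):

head = MultiClsHead(
    backbone='resnet50',
    in_indices=(1, 2, 3, 4),
    pool_type='specified',
    num_classes=1000,
    with_last_layer_unpool=False,
    norm_cfg=dict(type='BN'))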
@@ -43,15 +52,16 @@ class MultiClsHead(ClsHead):
def __init__(
self,
backbone: Optional[str] = 'resnet50',
in_indices: Optional[Sequence[int]] = (0, 1, 2, 3, 4),
pool_type: Optional[str] = 'adaptive',
num_classes: Optional[int] = 1000,
loss: Optional[Dict] = dict(
type='mmcls.CrossEntropyLoss', loss_weight=1.0),
with_last_layer_unpool: Optional[bool] = False,
norm_cfg: Optional[Dict] = dict(type='BN'),
init_cfg: Optional[Union[Dict, List[Dict]]] = [
backbone: str = 'resnet50',
in_indices: Sequence[int] = (0, 1, 2, 3, 4),
pool_type: str = 'adaptive',
num_classes: int = 1000,
loss: dict = dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
with_last_layer_unpool: bool = False,
cal_acc: bool = False,
topk: Union[int, Tuple[int]] = (1, ),
norm_cfg: dict = dict(type='BN'),
init_cfg: Union[Dict, List[Dict]] = [
dict(type='Normal', std=0.01, layer='Linear'),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
@@ -80,10 +90,12 @@ class MultiClsHead(ClsHead):
for i in in_indices
])
self.cal_acc = cal_acc
self.topk = topk
# build loss
self.loss_module = MODELS.build(loss)
def forward(self, feats):
def forward(self, feats: Union[list, tuple]) -> list:
"""Compute multi-head scores.
Args:
@@ -96,6 +108,7 @@ class MultiClsHead(ClsHead):
assert isinstance(feats, (list, tuple))
if self.with_last_layer_unpool:
last_feats = feats[-1]
feats = self.multi_pooling(feats)
if self.with_norm:
@@ -109,7 +122,7 @@ class MultiClsHead(ClsHead):
return cls_score
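The forward thus emits one score tensor per selected stage; for a batch of N images (shapes assumed for clarity):

scores = head(feats)   # list of (N, num_classes) tensors
assert len(scores) == len(head.in_indices)  # when with_last_layer_unpool=False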
def loss(self, feats: Sequence[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> Dict:
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the extracted features.
Args:
@@ -131,11 +144,17 @@ class MultiClsHead(ClsHead):
target = torch.hstack(
[data_sample.gt_label.label for data_sample in data_samples])
# compute loss
# compute loss and accuracy
losses = dict()
for i, score in zip(self.in_indices, cls_score):
losses[f'loss.{i + 1}'] = self.loss_module(score, target)
losses[f'accuracy.{i + 1}'] = Accuracy.calculate(score, target)
if self.cal_acc:
acc = Accuracy.calculate(score, target, topk=self.topk)
losses.update({
f'accuracy.{i+1}.top-{k}': a
for k, a in zip(self.topk, acc)
})
return losses
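For instance, with in_indices=(0,), cal_acc=True and topk=(1, 5), the returned dict carries the keys below (derived directly from the loop above):

# losses == {
#     'loss.1': <cross-entropy loss for stage 0>,
#     'accuracy.1.top-1': <top-1 accuracy>,
#     'accuracy.1.top-5': <top-5 accuracy>,
# }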
def predict(self, feats: Sequence[torch.Tensor],

View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
import torch.nn as nn
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
class MultiPooling(BaseModule):
@@ -27,12 +29,13 @@ class MultiPooling(BaseModule):
POOL_DIMS = {'resnet50': [9216, 9216, 8192, 9216, 8192]}
def __init__(self,
pool_type='adaptive',
in_indices=(0, ),
backbone='resnet50'):
super(MultiPooling, self).__init__()
pool_type: str = 'adaptive',
in_indices: tuple = (0, ),
backbone: str = 'resnet50') -> None:
super().__init__()
assert pool_type in ['adaptive', 'specified']
assert backbone == 'resnet50', 'Now only support resnet50.'
if pool_type == 'adaptive':
self.pools = nn.ModuleList([
nn.AdaptiveAvgPool2d(self.POOL_SIZES[backbone][i])
@@ -44,6 +47,6 @@ class MultiPooling(BaseModule):
for i in in_indices
])
def forward(self, x):
def forward(self, x: Union[List, Tuple]) -> List:
assert isinstance(x, (list, tuple))
return [p(xx) for p, xx in zip(self.pools, x)]
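In 'specified' mode each stage is pooled to a fixed output size so the flattened dims land on POOL_DIMS. A usage sketch with ResNet-50 feature shapes for a 224x224 input (shapes assumed):

import torch

pooling = MultiPooling(pool_type='specified', in_indices=(0, 1, 2, 3, 4))
feats = [
    torch.randn(2, 64, 56, 56),    # stem
    torch.randn(2, 256, 56, 56),   # stage 1
    torch.randn(2, 512, 28, 28),   # stage 2
    torch.randn(2, 1024, 14, 14),  # stage 3
    torch.randn(2, 2048, 7, 7),    # stage 4
]
outs = pooling(feats)
# per-stage flattened dims: 9216, 9216, 8192, 9216, 8192 == POOL_DIMS['resnet50']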

View File

@@ -18,7 +18,7 @@ class TestMultiClsHead(TestCase):
fake_data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
losses = head.loss(fake_in, fake_data_samples)
print(losses)
self.assertEqual(len(losses.keys()), 4)
self.assertEqual(len(losses.keys()), 2)
for k in losses.keys():
assert k.startswith('loss') or k.startswith('accuracy')
if k.startswith('loss'):