[Fix]: Fix multi-head config
parent
1e16016b27
commit
8910743c6e
|
@ -1,5 +1,10 @@
|
|||
model = dict(
|
||||
type='Classification',
|
||||
type='ImageClassifier',
|
||||
data_preprocessor=dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True,
|
||||
),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -8,7 +13,7 @@ model = dict(
|
|||
norm_cfg=dict(type='BN'),
|
||||
frozen_stages=-1),
|
||||
head=dict(
|
||||
type='MultiClsHead',
|
||||
type='mmselfsup.MultiClsHead',
|
||||
pool_type='specified',
|
||||
in_indices=[0, 1, 2, 3, 4],
|
||||
with_last_layer_unpool=False,
|
||||
|
|
|
@ -4,53 +4,77 @@ _base_ = [
|
|||
'../_base_/schedules/sgd_steplr-100e.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
# Multi-head linear evaluation setting
|
||||
|
||||
model = dict(backbone=dict(frozen_stages=4))
|
||||
# lighting params, in order of BGR
|
||||
EIGVAL = [55.4625, 4.7940, 1.1475]
|
||||
EIGVEC = [
|
||||
[-0.5836, -0.6948, 0.4203],
|
||||
[-0.5808, -0.0045, -0.8140],
|
||||
[-0.5675, 0.7192, 0.4009],
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
# dataset
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Lighting'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(
|
||||
type='Lighting',
|
||||
eigval=EIGVAL,
|
||||
eigvec=EIGVEC,
|
||||
),
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='Resize', size=256),
|
||||
dict(type='CenterCrop', size=224),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs')
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline))
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
|
||||
# MoCo v1/v2 linear evaluation setting
|
||||
model = dict(
|
||||
backbone=dict(out_indices=[0, 1, 2, 3], frozen_stages=4),
|
||||
head=dict(in_indices=[1, 2, 3, 4]))
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD',
|
||||
lr=0.01,
|
||||
momentum=0.9,
|
||||
weight_decay=1e-4,
|
||||
paramwise_options=dict(norm_decay_mult=0.),
|
||||
nesterov=True)
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
|
||||
]
|
||||
type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4, nesterov=True)
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
paramwise_cfg=dict(norm_decay_mult=0.0))
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=90)
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
|
||||
|
||||
# max_keep_ckpts controls the maximum number of ckpt files kept in your work_dirs
|
||||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
|
||||
# it will remove the oldest one to keep the number of total ckpts as 3
|
||||
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
|
||||
# evaluator
|
||||
val_evaluator = dict(
|
||||
_delete_=True,
|
||||
type='MultiHeadEvaluator',
|
||||
metrics=dict(
|
||||
head1=dict(type='mmcls.Accuracy', topk=(1, 5)),
|
||||
head2=dict(type='mmcls.Accuracy', topk=(1, 5)),
|
||||
head3=dict(type='mmcls.Accuracy', topk=(1, 5)),
|
||||
head4=dict(type='mmcls.Accuracy', topk=(1, 5))))
|
||||
|
||||
# epochs
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=90)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -22,20 +22,29 @@ class MultiClsHead(ClsHead):
|
|||
linear classifier at each stage to predict corresponding class scores.
|
||||
|
||||
Args:
|
||||
backbone (str, optional): Specify which backbone to use, only support
|
||||
backbone (str): Specify which backbone to use, only support
|
||||
ResNet50. Defaults to 'resnet50'.
|
||||
in_indices (Sequence[int], optional): Input from which stages.
|
||||
in_indices (Sequence[int]): Input from which stages.
|
||||
Defaults to (0, 1, 2, 3, 4).
|
||||
pool_type (str, optional): 'adaptive' or 'specified'. If set to
|
||||
pool_type (str): 'adaptive' or 'specified'. If set to
|
||||
'adaptive', use adaptive average pooling, otherwise use specified
|
||||
pooling params. Defaults to 'adaptive'.
|
||||
num_classes (int, optional): Number of classes. Defaults to 1000.
|
||||
loss (Dict, optional): The Dict of loss information. Defaults to
|
||||
'mmcls.models.CrossEntropyLoss'
|
||||
with_last_layer_unpool (bool, optional): Whether to unpool the features
|
||||
num_classes (int): Number of classes. Defaults to 1000.
|
||||
loss (dict): The Dict of loss information. Defaults to
|
||||
'mmcls.models.CrossEntropyLoss'
|
||||
with_last_layer_unpool (bool): Whether to unpool the features
|
||||
from last layer. Defaults to False.
|
||||
norm_cfg (Dict, optional): Dict to construct and config norm layer.
|
||||
init_cfg (Dict or List[Dict], optional): Initialization config dict.
|
||||
cal_acc (bool): Whether to calculate accuracy during training.
|
||||
If you use batch augmentations like Mixup and CutMix during
|
||||
training, it is pointless to calculate accuracy.
|
||||
Defaults to False.
|
||||
topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
|
||||
norm_cfg (dict): Dict to construct and config norm layer.
|
||||
Defaults to ``dict(type='BN')``.
|
||||
init_cfg (dict or List[dict]): Initialization config dict.
|
||||
Defaults to ``[
|
||||
dict(type='Normal', std=0.01, layer='Linear'),
|
||||
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]``
|
||||
"""
|
||||
|
||||
FEAT_CHANNELS = {'resnet50': [64, 256, 512, 1024, 2048]}
|
||||
|
@ -43,15 +52,16 @@ class MultiClsHead(ClsHead):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: Optional[str] = 'resnet50',
|
||||
in_indices: Optional[Sequence[int]] = (0, 1, 2, 3, 4),
|
||||
pool_type: Optional[str] = 'adaptive',
|
||||
num_classes: Optional[int] = 1000,
|
||||
loss: Optional[Dict] = dict(
|
||||
type='mmcls.CrossEntropyLoss', loss_weight=1.0),
|
||||
with_last_layer_unpool: Optional[bool] = False,
|
||||
norm_cfg: Optional[Dict] = dict(type='BN'),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
||||
backbone: str = 'resnet50',
|
||||
in_indices: Sequence[int] = (0, 1, 2, 3, 4),
|
||||
pool_type: str = 'adaptive',
|
||||
num_classes: int = 1000,
|
||||
loss: dict = dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
|
||||
with_last_layer_unpool: bool = False,
|
||||
cal_acc: bool = False,
|
||||
topk: Union[int, Tuple[int]] = (1, ),
|
||||
norm_cfg: dict = dict(type='BN'),
|
||||
init_cfg: Union[Dict, List[Dict]] = [
|
||||
dict(type='Normal', std=0.01, layer='Linear'),
|
||||
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
|
@ -80,10 +90,12 @@ class MultiClsHead(ClsHead):
|
|||
for i in in_indices
|
||||
])
|
||||
|
||||
self.cal_acc = cal_acc
|
||||
self.topk = topk
|
||||
# build loss
|
||||
self.loss_module = MODELS.build(loss)
|
||||
|
||||
def forward(self, feats):
|
||||
def forward(self, feats: Union[list, tuple]) -> list:
|
||||
"""Compute multi-head scores.
|
||||
|
||||
Args:
|
||||
|
@ -96,6 +108,7 @@ class MultiClsHead(ClsHead):
|
|||
assert isinstance(feats, (list, tuple))
|
||||
if self.with_last_layer_unpool:
|
||||
last_feats = feats[-1]
|
||||
|
||||
feats = self.multi_pooling(feats)
|
||||
|
||||
if self.with_norm:
|
||||
|
@ -109,7 +122,7 @@ class MultiClsHead(ClsHead):
|
|||
return cls_score
|
||||
|
||||
def loss(self, feats: Sequence[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> Dict:
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
"""Calculate losses from the extracted features.
|
||||
|
||||
Args:
|
||||
|
@ -131,11 +144,17 @@ class MultiClsHead(ClsHead):
|
|||
target = torch.hstack(
|
||||
[data_sample.gt_label.label for data_sample in data_samples])
|
||||
|
||||
# compute loss
|
||||
# compute loss and accuracy
|
||||
losses = dict()
|
||||
for i, score in zip(self.in_indices, cls_score):
|
||||
losses[f'loss.{i + 1}'] = self.loss_module(score, target)
|
||||
losses[f'accuracy.{i + 1}'] = Accuracy.calculate(score, target)
|
||||
if self.cal_acc:
|
||||
acc = Accuracy.calculate(score, target, topk=self.topk)
|
||||
losses.update({
|
||||
f'accuracy.{i+1}.top-{k}': a
|
||||
for k, a in zip(self.topk, acc)
|
||||
})
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, feats: Sequence[torch.Tensor],
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
||||
class MultiPooling(BaseModule):
|
||||
|
@ -27,12 +29,13 @@ class MultiPooling(BaseModule):
|
|||
POOL_DIMS = {'resnet50': [9216, 9216, 8192, 9216, 8192]}
|
||||
|
||||
def __init__(self,
|
||||
pool_type='adaptive',
|
||||
in_indices=(0, ),
|
||||
backbone='resnet50'):
|
||||
super(MultiPooling, self).__init__()
|
||||
pool_type: str = 'adaptive',
|
||||
in_indices: tuple = (0, ),
|
||||
backbone: str = 'resnet50') -> None:
|
||||
super().__init__()
|
||||
assert pool_type in ['adaptive', 'specified']
|
||||
assert backbone == 'resnet50', 'Now only support resnet50.'
|
||||
|
||||
if pool_type == 'adaptive':
|
||||
self.pools = nn.ModuleList([
|
||||
nn.AdaptiveAvgPool2d(self.POOL_SIZES[backbone][i])
|
||||
|
@ -44,6 +47,6 @@ class MultiPooling(BaseModule):
|
|||
for i in in_indices
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: Union[List, Tuple]) -> List:
    """Apply each pooling layer to its corresponding input feature.

    Args:
        x (List | Tuple): Per-stage inputs, one entry per layer in
            ``self.pools`` (presumably backbone feature maps; confirm
            against the caller).

    Returns:
        List: The pooled outputs, in the same order as ``x``. Pairing is
        done with ``zip``, so extra entries on either side are ignored.
    """
    assert isinstance(x, (list, tuple))
    # The method returns a list, so the annotation must not be ``None``.
    return [pool(feat) for pool, feat in zip(self.pools, x)]
|
||||
|
|
|
@ -18,7 +18,7 @@ class TestMultiClsHead(TestCase):
|
|||
fake_data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
losses = head.loss(fake_in, fake_data_samples)
|
||||
print(losses)
|
||||
self.assertEqual(len(losses.keys()), 4)
|
||||
self.assertEqual(len(losses.keys()), 2)
|
||||
for k in losses.keys():
|
||||
assert k.startswith('loss') or k.startswith('accuracy')
|
||||
if k.startswith('loss'):
|
||||
|
|
Loading…
Reference in New Issue