EasyCV/easycv/models/classification/classification.py

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import get_dist_info
from timm.data.mixup import Mixup
from easycv.framework.errors import KeyError, NotImplementedError, ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.preprocess_function import (bninceptionPre, gaussianBlur,
mixUpCls, randomErasing)
from .. import builder
from ..base import BaseModel
from ..registry import MODELS
from ..utils import Sobel
@MODELS.register_module
class Classification(BaseModel):
"""
Args:
pretrained: Select one {str or True or False/None}.
if pretrained == str, load model from specified path;
if pretrained == True, load model from default path(currently only supports timm);
if pretrained == False or None, load from init weights.
"""
def __init__(self,
backbone,
train_preprocess=[],
with_sobel=False,
head=None,
neck=None,
pretrained=True,
mixup_cfg=None):
super(Classification, self).__init__()
self.with_sobel = with_sobel
self.pretrained = pretrained
if with_sobel:
self.sobel_layer = Sobel()
else:
self.sobel_layer = None
self.preprocess_key_map = {
'bninceptionPre': bninceptionPre,
'gaussianBlur': gaussianBlur,
'mixUpCls': mixUpCls,
'randomErasing': randomErasing
}
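        # Names listed in `train_preprocess` are mapped through this dict to callables,
        # which forward_train applies to each batch in order.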
if 'mixUp' in train_preprocess:
rank, _ = get_dist_info()
np.random.seed(rank + 12)
if mixup_cfg is not None:
if 'num_classes' in mixup_cfg:
self.mixup = Mixup(**mixup_cfg)
elif 'num_classes' in head or 'num_classes' in backbone:
num_classes = head.get(
'num_classes'
) if 'num_classes' in head else backbone.get('num_classes')
mixup_cfg['num_classes'] = num_classes
self.mixup = Mixup(**mixup_cfg)
train_preprocess.remove('mixUp')
self.train_preprocess = [
self.preprocess_key_map[i] for i in train_preprocess
]
self.backbone = builder.build_backbone(backbone)
        assert head is not None, 'Classification head should be configured'
if type(head) == list:
self.head_num = len(head)
tmp_head_list = [builder.build_head(h) for h in head]
else:
self.head_num = 1
tmp_head_list = [builder.build_head(head)]
        # use setattr so that each head is registered as a submodule of this nn.Module
for idx, h in enumerate(tmp_head_list):
setattr(self, 'head_%d' % idx, h)
if type(neck) == list:
self.neck_num = len(neck)
tmp_neck_list = [builder.build_neck(n) for n in neck]
elif neck is not None:
self.neck_num = 1
tmp_neck_list = [builder.build_neck(neck)]
else:
self.neck_num = 0
tmp_neck_list = []
        # use setattr so that each neck is registered as a submodule of this nn.Module
for idx, n in enumerate(tmp_neck_list):
setattr(self, 'neck_%d' % idx, n)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.activate_fn = nn.Softmax(dim=1)
self.extract_list = ['neck']
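        # feature names returned in 'extract' mode; extended via update_extract_list()
        # and, when necks are present, with per-neck keys inside forward_feature()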
self.init_weights()
def init_weights(self):
logger = get_root_logger()
if isinstance(self.pretrained, str):
load_checkpoint(
self.backbone, self.pretrained, strict=False, logger=logger)
elif self.pretrained:
if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
self.backbone.init_weights(pretrained=self.pretrained)
elif hasattr(self.backbone, 'default_pretrained_model_path'
) and self.backbone.default_pretrained_model_path:
print_log(
'load model from default path: {}'.format(
self.backbone.default_pretrained_model_path), logger)
load_checkpoint(
self.backbone,
self.backbone.default_pretrained_model_path,
strict=False,
logger=logger,
revise_keys=[
(r'^backbone\.', '')
]) # revise_keys is used to avoid load mismatches
else:
raise ValueError(
'default_pretrained_model_path for {} not found'.format(
self.backbone.__class__.__name__))
else:
print_log('load model from init weights')
self.backbone.init_weights()
for idx in range(self.head_num):
h = getattr(self, 'head_%d' % idx)
h.init_weights()
for idx in range(self.neck_num):
n = getattr(self, 'neck_%d' % idx)
n.init_weights()
def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
"""Forward backbone
Returns:
x (tuple): backbone outputs
"""
if self.sobel_layer is not None:
img = self.sobel_layer(img)
x = self.backbone(img)
return x
@torch.jit.unused
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
"""
In forward train, model will forward backbone + neck / multi-neck, get alist of output tensor,
and put this list to head / multi-head, to compute each loss
"""
# for mxk sampler, use datasource type = ClsSourceImageListByClass will sample k img in 1 class,
# input data will be m x k x c x h x w, should be reshape to (m x k) x c x h x w
if img.dim() == 5:
new_shape = [
img.shape[0] * img.shape[1], img.shape[2], img.shape[3],
img.shape[4]
]
img = img.view(new_shape)
gt_labels = gt_labels.view([-1])
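            # e.g. img of shape (m, k, 3, H, W) becomes (m * k, 3, H, W) and gt_labels is
            # flattened to a 1-D tensor of length m * k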
for preprocess in self.train_preprocess:
img = preprocess(img)
# When the number of samples in the dataset is odd, the last batch size of each epoch will be odd,
# which will cause mixup to report an error. To avoid this situation, mixup is applied only when
# the batch size is even.
if hasattr(self, 'mixup') and len(img) % 2 == 0:
img, gt_labels = self.mixup(img, gt_labels)
x = self.forward_backbone(img)
if self.neck_num > 0:
tmp = []
for idx in range(self.neck_num):
h = getattr(self, 'neck_%d' % idx)
tmp += h(x)
x = tmp
else:
x = x
losses = {}
for idx in range(self.head_num):
h = getattr(self, 'head_%d' % idx)
outs = h(x)
loss_inputs = (outs, gt_labels)
hlosses = h.loss(*loss_inputs)
if 'loss' in losses.keys():
losses['loss'] += hlosses['loss']
else:
losses['loss'] = hlosses['loss']
return losses
# @torch.jit.unused
def forward_test(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
forward_test means generate prob/class from image only support one neck + one head
"""
x = self.forward_backbone(img) # tuple
# if self.neck_num > 0:
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = self.head_0(x)[0].cpu()
result = {}
result['prob'] = self.activate_fn(out)
result['class'] = torch.argmax(result['prob'])
return result
@torch.jit.unused
def forward_test_label(self, img, gt_labels) -> Dict[str, torch.Tensor]:
"""
forward_test_label means generate prob/class from image only support one neck + one head
ps : head init need set the input feature idx
"""
x = self.forward_backbone(img) # tuple
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = [self.head_0(x)[0].cpu()]
keys = ['neck']
keys.append('gt_labels')
out.append(gt_labels.cpu())
return dict(zip(keys, out))
def aug_test(self, imgs):
raise NotImplementedError
def forward_feature(self, img) -> Dict[str, torch.Tensor]:
"""Forward feature means forward backbone + neck/multineck ,get dict of output feature,
self.neck_num = 0: means only forward backbone, output backbone feature with avgpool, with key neck,
self.neck_num > 0: means has 1/multi neck, output neck's feature with key neck_neckidx_featureidx, suck as neck_0_0
Returns:
x (torch.Tensor): feature tensor
"""
return_dict = {}
x = self.backbone(img)
# return_dict['backbone'] = x[-1]
if hasattr(self, 'neck_0'):
tmp = []
for idx in range(self.neck_num):
neck_name = 'neck_%d' % idx
h = getattr(self, neck_name)
neck_h = h([i for i in x])
tmp = tmp + neck_h
for j in range(len(neck_h)):
neck_name = 'neck_%d_%d' % (idx, j)
return_dict['neck_%d_%d' % (idx, j)] = neck_h[j]
if neck_name not in self.extract_list:
self.extract_list.append(neck_name)
return_dict['neck'] = tmp[0]
else:
feature = self.avg_pool(x[-1])
feature = feature.view(feature.size(0), -1)
return_dict['neck'] = feature
return return_dict
def update_extract_list(self, key):
if key not in self.extract_list:
self.extract_list.append(key)
return
def forward(
self,
img: torch.Tensor,
mode: str = 'train',
gt_labels: Optional[torch.Tensor] = None,
img_metas: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
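        """Dispatch on `mode`: 'train' computes losses (gt_labels required), 'test' returns
        prob/class predictions (plus gt_labels if given), and 'extract' returns the features
        named in self.extract_list.
        """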
# TODO: support Dict[any, any] type of img_metas
        del img_metas  # img_metas is a placeholder kept for jit support
if mode == 'train':
assert gt_labels is not None
return self.forward_train(img, gt_labels)
elif mode == 'test':
if gt_labels is not None:
return self.forward_test_label(img, gt_labels)
else:
return self.forward_test(img)
elif mode == 'extract':
rd = self.forward_feature(img)
rv = {}
for name in self.extract_list:
if name in rd.keys():
rv[name] = rd[name].cpu()
else:
                    raise ValueError(
                        'Extract {} is not supported in classification models'.
                        format(name))
if gt_labels is not None:
rv['gt_labels'] = gt_labels.cpu()
return rv
else:
raise KeyError('No such mode: {}'.format(mode))
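

# A minimal usage sketch. The ResNet/ClsHead config values below are illustrative
# assumptions patterned after EasyCV classification configs; adjust them to the
# backbone/head actually registered in your setup.
if __name__ == '__main__':
    model = Classification(
        backbone=dict(
            type='ResNet', depth=50, out_indices=[4], norm_cfg=dict(type='BN')),
        head=dict(
            type='ClsHead', with_avg_pool=True, in_channels=2048,
            num_classes=1000),
        pretrained=False)  # skip pretrained loading in this sketch
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        result = model(dummy, mode='test')
    print(result['prob'].shape, result['class'])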