Mirror of https://github.com/alibaba/EasyCV.git (synced 2025-06-03 14:49:00 +08:00)
1. Use the original config as the startup script. (For details, see refactor config parsing method #225)
2. Refactor the splicing rules of the check_base_cfg_path function in EasyCV/easycv/utils/config_tools.py.
3. Support three ways to pass the class_list parameter.
4. Fix the bug that ClsEvaluator may make mistakes when evaluating top-5.
5. Fix the bug that distributed export cannot export the model.
6. Fix the bug that loaded pretrained model keys do not match.
7. Support the cls data source itag.
308 lines
11 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import get_dist_info
from timm.data.mixup import Mixup

from easycv.framework.errors import KeyError, NotImplementedError, ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.preprocess_function import (bninceptionPre, gaussianBlur,
                                              mixUpCls, randomErasing)
from .. import builder
from ..base import BaseModel
from ..registry import MODELS
from ..utils import Sobel


@MODELS.register_module
class Classification(BaseModel):
    """
    Args:
        pretrained: Select one of {str, True, False/None}.
            if pretrained == str, load the model from the specified path;
            if pretrained == True, load the model from the default path (currently only supports timm);
            if pretrained == False or None, initialize from init weights.
    """

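    # The three pretrained modes in practice (illustrative sketch; the checkpoint
    # path and the backbone/head config dicts are placeholder assumptions, not
    # taken from this file):
    #
    # >>> Classification(backbone_cfg, head=head_cfg, pretrained='work_dirs/epoch_10.pth')
    # >>> Classification(backbone_cfg, head=head_cfg, pretrained=True)   # default path
    # >>> Classification(backbone_cfg, head=head_cfg, pretrained=False)  # init weights
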
    def __init__(self,
                 backbone,
                 train_preprocess=[],
                 with_sobel=False,
                 head=None,
                 neck=None,
                 pretrained=True,
                 mixup_cfg=None):
        super(Classification, self).__init__()
        self.with_sobel = with_sobel
        self.pretrained = pretrained
        if with_sobel:
            self.sobel_layer = Sobel()
        else:
            self.sobel_layer = None

        # map preprocess names in the config to their callables
        self.preprocess_key_map = {
            'bninceptionPre': bninceptionPre,
            'gaussianBlur': gaussianBlur,
            'mixUpCls': mixUpCls,
            'randomErasing': randomErasing
        }

        if 'mixUp' in train_preprocess:
            # seed numpy differently on each rank
            rank, _ = get_dist_info()
            np.random.seed(rank + 12)
            if mixup_cfg is not None:
                if 'num_classes' in mixup_cfg:
                    self.mixup = Mixup(**mixup_cfg)
                elif 'num_classes' in head or 'num_classes' in backbone:
                    # fall back to the num_classes declared in the head or backbone config
                    num_classes = head.get(
                        'num_classes'
                    ) if 'num_classes' in head else backbone.get('num_classes')
                    mixup_cfg['num_classes'] = num_classes
                    self.mixup = Mixup(**mixup_cfg)
            train_preprocess.remove('mixUp')
        self.train_preprocess = [
            self.preprocess_key_map[i] for i in train_preprocess
        ]

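        # A mixup_cfg sketch (illustrative; the values are assumptions, see
        # timm.data.mixup.Mixup for the full argument set):
        #
        # >>> mixup_cfg = dict(mixup_alpha=0.2, cutmix_alpha=1.0, prob=1.0,
        # ...                  switch_prob=0.5, label_smoothing=0.1,
        # ...                  num_classes=1000)
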
        self.backbone = builder.build_backbone(backbone)

        assert head is not None, 'Classification head should be configured'

        if type(head) == list:
            self.head_num = len(head)
            tmp_head_list = [builder.build_head(h) for h in head]
        else:
            self.head_num = 1
            tmp_head_list = [builder.build_head(head)]

        # setattr each head so it is registered as a submodule of this nn.Module
        for idx, h in enumerate(tmp_head_list):
            setattr(self, 'head_%d' % idx, h)

        if type(neck) == list:
            self.neck_num = len(neck)
            tmp_neck_list = [builder.build_neck(n) for n in neck]
        elif neck is not None:
            self.neck_num = 1
            tmp_neck_list = [builder.build_neck(neck)]
        else:
            self.neck_num = 0
            tmp_neck_list = []

        # setattr each neck so it is registered as a submodule of this nn.Module
        for idx, n in enumerate(tmp_neck_list):
            setattr(self, 'neck_%d' % idx, n)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.activate_fn = nn.Softmax(dim=1)
        self.extract_list = ['neck']

        self.init_weights()

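    # Construction sketch (illustrative; the 'ResNet' and 'ClsHead' config values
    # are assumptions about EasyCV's registry, not taken from this file):
    #
    # >>> model = Classification(
    # ...     backbone=dict(type='ResNet', depth=50, out_indices=[4]),
    # ...     head=dict(type='ClsHead', with_avg_pool=True, in_channels=2048,
    # ...               num_classes=1000),
    # ...     pretrained=False)
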
    def init_weights(self):
        logger = get_root_logger()
        if isinstance(self.pretrained, str):
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        elif self.pretrained:
            if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
                self.backbone.init_weights(pretrained=self.pretrained)
            elif hasattr(self.backbone, 'default_pretrained_model_path'
                         ) and self.backbone.default_pretrained_model_path:
                print_log(
                    'load model from default path: {}'.format(
                        self.backbone.default_pretrained_model_path), logger)
                load_checkpoint(
                    self.backbone,
                    self.backbone.default_pretrained_model_path,
                    strict=False,
                    logger=logger,
                    revise_keys=[
                        (r'^backbone\.', '')
                    ])  # revise_keys strips the 'backbone.' prefix to avoid key mismatches
            else:
                raise ValueError(
                    'default_pretrained_model_path for {} not found'.format(
                        self.backbone.__class__.__name__))
        else:
            print_log('load model from init weights')
            self.backbone.init_weights()

        for idx in range(self.head_num):
            h = getattr(self, 'head_%d' % idx)
            h.init_weights()

        for idx in range(self.neck_num):
            n = getattr(self, 'neck_%d' % idx)
            n.init_weights()

    def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
        """Forward backbone.

        Returns:
            x (tuple): backbone outputs
        """
        if self.sobel_layer is not None:
            img = self.sobel_layer(img)
        x = self.backbone(img)
        return x

    @torch.jit.unused
    def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
        """
        In forward_train, the model forwards backbone + neck / multi-neck to get a
        list of output tensors, then feeds this list to the head / multi-head to
        compute each loss.
        """
        # For the mxk sampler (datasource type ClsSourceImageListByClass, which
        # samples k images per class), input data is m x k x c x h x w and must
        # be reshaped to (m x k) x c x h x w.
        if img.dim() == 5:
            new_shape = [
                img.shape[0] * img.shape[1], img.shape[2], img.shape[3],
                img.shape[4]
            ]
            img = img.view(new_shape)
            gt_labels = gt_labels.view([-1])

        for preprocess in self.train_preprocess:
            img = preprocess(img)

        # When the number of samples in the dataset is odd, the last batch of
        # each epoch will have an odd size, which makes mixup raise an error.
        # To avoid this, mixup is applied only when the batch size is even.
        if hasattr(self, 'mixup') and len(img) % 2 == 0:
            img, gt_labels = self.mixup(img, gt_labels)

        x = self.forward_backbone(img)

        if self.neck_num > 0:
            tmp = []
            for idx in range(self.neck_num):
                n = getattr(self, 'neck_%d' % idx)
                tmp += n(x)
            x = tmp

        losses = {}
        for idx in range(self.head_num):
            h = getattr(self, 'head_%d' % idx)
            outs = h(x)
            loss_inputs = (outs, gt_labels)
            hlosses = h.loss(*loss_inputs)
            # accumulate the loss across heads
            if 'loss' in losses:
                losses['loss'] += hlosses['loss']
            else:
                losses['loss'] = hlosses['loss']

        return losses

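    # Training-mode usage sketch (illustrative; the shapes and class count are
    # assumptions):
    #
    # >>> img = torch.randn(8, 3, 224, 224)
    # >>> gt_labels = torch.randint(0, 1000, (8,))
    # >>> losses = model(img, mode='train', gt_labels=gt_labels)
    # >>> losses['loss'].backward()
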
    # @torch.jit.unused
    def forward_test(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        forward_test generates prob/class from an image; only supports one neck + one head.
        """
        x = self.forward_backbone(img)  # tuple

        if hasattr(self, 'neck_0'):
            x = self.neck_0([i for i in x])

        out = self.head_0(x)[0].cpu()
        result = {}
        result['prob'] = self.activate_fn(out)
        # argmax over the class dimension, so batched inputs get one prediction per sample
        result['class'] = torch.argmax(result['prob'], dim=1)
        return result

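    # Test-mode usage sketch (illustrative; the input shape is an assumption):
    #
    # >>> model.eval()
    # >>> with torch.no_grad():
    # ...     result = model(torch.randn(1, 3, 224, 224), mode='test')
    # >>> result['prob']   # (1, num_classes) softmax probabilities
    # >>> result['class']  # predicted class index per sample
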
    @torch.jit.unused
    def forward_test_label(self, img, gt_labels) -> Dict[str, torch.Tensor]:
        """
        forward_test_label generates prob/class from an image together with its
        ground-truth labels; only supports one neck + one head.
        PS: the head init needs to set the input feature idx.
        """
        x = self.forward_backbone(img)  # tuple

        if hasattr(self, 'neck_0'):
            x = self.neck_0([i for i in x])

        out = [self.head_0(x)[0].cpu()]
        keys = ['neck', 'gt_labels']
        out.append(gt_labels.cpu())
        return dict(zip(keys, out))

    def aug_test(self, imgs):
        raise NotImplementedError

    def forward_feature(self, img) -> Dict[str, torch.Tensor]:
        """Forward backbone + neck / multi-neck and get a dict of output features.

        self.neck_num == 0: only forward the backbone; output the avg-pooled
            backbone feature under key 'neck'.
        self.neck_num > 0: with one or more necks, output each neck's features
            under key 'neck_<neckidx>_<featureidx>', such as 'neck_0_0'.

        Returns:
            return_dict (Dict[str, torch.Tensor]): feature tensors
        """
        return_dict = {}
        x = self.backbone(img)
        # return_dict['backbone'] = x[-1]
        if hasattr(self, 'neck_0'):
            tmp = []
            for idx in range(self.neck_num):
                neck_name = 'neck_%d' % idx
                n = getattr(self, neck_name)
                neck_h = n([i for i in x])
                tmp = tmp + neck_h
                for j in range(len(neck_h)):
                    neck_name = 'neck_%d_%d' % (idx, j)
                    return_dict[neck_name] = neck_h[j]
                    if neck_name not in self.extract_list:
                        self.extract_list.append(neck_name)
            return_dict['neck'] = tmp[0]
        else:
            feature = self.avg_pool(x[-1])
            feature = feature.view(feature.size(0), -1)
            return_dict['neck'] = feature

        return return_dict

    def update_extract_list(self, key):
        if key not in self.extract_list:
            self.extract_list.append(key)
        return

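    # Extract-mode usage sketch (illustrative; the available key names depend on
    # the configured necks):
    #
    # >>> model.update_extract_list('neck')
    # >>> feats = model(torch.randn(2, 3, 224, 224), mode='extract')
    # >>> feats['neck']  # (2, feature_dim) pooled features
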
    def forward(
            self,
            img: torch.Tensor,
            mode: str = 'train',
            gt_labels: Optional[torch.Tensor] = None,
            img_metas: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:

        # TODO: support Dict[any, any] type of img_metas
        del img_metas  # fake img_metas to support jit
        if mode == 'train':
            assert gt_labels is not None
            return self.forward_train(img, gt_labels)
        elif mode == 'test':
            if gt_labels is not None:
                return self.forward_test_label(img, gt_labels)
            else:
                return self.forward_test(img)
        elif mode == 'extract':
            rd = self.forward_feature(img)
            rv = {}
            for name in self.extract_list:
                if name in rd:
                    rv[name] = rd[name].cpu()
                else:
                    raise ValueError(
                        'Extract {} is not supported in classification models'.
                        format(name))
            if gt_labels is not None:
                rv['gt_labels'] = gt_labels.cpu()
            return rv
        else:
            raise KeyError('No such mode: {}'.format(mode))
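
# End-to-end sketch (illustrative; kept commented out so importing this module
# has no side effects; the 'ResNet'/'ClsHead' configs are assumptions about
# EasyCV's registry, not taken from this file):
#
# if __name__ == '__main__':
#     model = Classification(
#         backbone=dict(type='ResNet', depth=50, out_indices=[4]),
#         head=dict(type='ClsHead', with_avg_pool=True, in_channels=2048,
#                   num_classes=10),
#         pretrained=False)
#     out = model(torch.randn(2, 3, 224, 224), mode='test')
#     print(out['class'])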