# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import get_dist_info
from timm.data.mixup import Mixup

from easycv.framework.errors import KeyError, NotImplementedError, ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.preprocess_function import (bninceptionPre, gaussianBlur,
                                              mixUpCls, randomErasing)
from .. import builder
from ..base import BaseModel
from ..registry import MODELS
from ..utils import Sobel


@MODELS.register_module
class Classification(BaseModel):
    """Image classification model composed of a backbone, optional neck(s)
    and head(s).

    Args:
        pretrained: One of {str, True, False/None}.
            If pretrained is a str, load the model from the specified path;
            if pretrained is True, load the model from the default path
            (currently only supported for timm backbones);
            if pretrained is False or None, initialize from init weights.
    """

    def __init__(self,
                 backbone,
                 train_preprocess=[],
                 with_sobel=False,
                 head=None,
                 neck=None,
                 pretrained=True,
                 mixup_cfg=None):
        super(Classification, self).__init__()
        self.with_sobel = with_sobel
        self.pretrained = pretrained
        if with_sobel:
            self.sobel_layer = Sobel()
        else:
            self.sobel_layer = None

        self.preprocess_key_map = {
            'bninceptionPre': bninceptionPre,
            'gaussianBlur': gaussianBlur,
            'mixUpCls': mixUpCls,
            'randomErasing': randomErasing
        }
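
        # A plausible mixup_cfg sketch, forwarded verbatim to timm's Mixup
        # (the values below are illustrative, not recommended defaults):
        #
        #   mixup_cfg = dict(
        #       mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
        #       mode='batch', label_smoothing=0.1, num_classes=1000)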

        if 'mixUp' in train_preprocess:
            rank, _ = get_dist_info()
            np.random.seed(rank + 12)
            if mixup_cfg is not None:
                if 'num_classes' in mixup_cfg:
                    self.mixup = Mixup(**mixup_cfg)
                elif 'num_classes' in head or 'num_classes' in backbone:
                    num_classes = head.get(
                        'num_classes'
                    ) if 'num_classes' in head else backbone.get('num_classes')
                    mixup_cfg['num_classes'] = num_classes
                    self.mixup = Mixup(**mixup_cfg)
            train_preprocess.remove('mixUp')
        self.train_preprocess = [
            self.preprocess_key_map[i] for i in train_preprocess
        ]

        self.backbone = builder.build_backbone(backbone)

        assert head is not None, 'Classification head should be configured'

        if isinstance(head, list):
            self.head_num = len(head)
            tmp_head_list = [builder.build_head(h) for h in head]
        else:
            self.head_num = 1
            tmp_head_list = [builder.build_head(head)]

        # register each head as a direct attribute so that nn.Module
        # tracks its parameters
        for idx, h in enumerate(tmp_head_list):
            setattr(self, 'head_%d' % idx, h)

        if isinstance(neck, list):
            self.neck_num = len(neck)
            tmp_neck_list = [builder.build_neck(n) for n in neck]
        elif neck is not None:
            self.neck_num = 1
            tmp_neck_list = [builder.build_neck(neck)]
        else:
            self.neck_num = 0
            tmp_neck_list = []

        # register each neck as a direct attribute so that nn.Module
        # tracks its parameters
        for idx, n in enumerate(tmp_neck_list):
            setattr(self, 'neck_%d' % idx, n)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.activate_fn = nn.Softmax(dim=1)
        self.extract_list = ['neck']

        self.init_weights()

    def init_weights(self):
        logger = get_root_logger()
        if isinstance(self.pretrained, str):
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        elif self.pretrained:
            if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
                self.backbone.init_weights(pretrained=self.pretrained)
            elif hasattr(self.backbone, 'default_pretrained_model_path'
                         ) and self.backbone.default_pretrained_model_path:
                print_log(
                    'load model from default path: {}'.format(
                        self.backbone.default_pretrained_model_path), logger)
                load_checkpoint(
                    self.backbone,
                    self.backbone.default_pretrained_model_path,
                    strict=False,
                    logger=logger,
                    revise_keys=[
                        (r'^backbone\.', '')
                    ])  # strip the 'backbone.' prefix to avoid key mismatches
            else:
                raise ValueError(
                    'default_pretrained_model_path for {} not found'.format(
                        self.backbone.__class__.__name__))
        else:
            print_log('load model from init weights')
            self.backbone.init_weights()

        for idx in range(self.head_num):
            h = getattr(self, 'head_%d' % idx)
            h.init_weights()

        for idx in range(self.neck_num):
            n = getattr(self, 'neck_%d' % idx)
            n.init_weights()

    def forward_backbone(self, img: torch.Tensor) -> List[torch.Tensor]:
        """Forward the backbone (after the optional Sobel layer).

        Returns:
            x (list): backbone output tensors
        """
        if self.sobel_layer is not None:
            img = self.sobel_layer(img)
        x = self.backbone(img)
        return x

    def forward_onnx(self, img: torch.Tensor) -> torch.Tensor:
        """Generate class probabilities from an image, for ONNX export.

        Only supports one neck + one head.
        """
        x = self.forward_backbone(img)  # list of tensors

        # hasattr check kept instead of `if self.neck_num > 0` for jit/onnx
        # compatibility
        if hasattr(self, 'neck_0'):
            x = self.neck_0([i for i in x])

        out = self.head_0(x)[0].cpu()
        out = self.activate_fn(out)
        return out

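    # Export-time sketch (hypothetical usage; EasyCV's own export tooling may
    # wire this up differently):
    #
    #   dummy = torch.randn(1, 3, 224, 224)
    #   probs = model(dummy, mode='onnx')  # softmax scores, (1, num_classes)
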
    @torch.jit.unused
    def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
        """Forward the backbone + neck / multi-neck to get a list of output
        tensors, then feed that list to the head / multi-head to compute
        each loss.
        """
        # For the m x k sampler (datasource type ClsSourceImageListByClass,
        # which samples k images from each of m classes), the input is
        # m x k x c x h x w and must be reshaped to (m*k) x c x h x w.
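        # e.g. (hypothetical shapes) img (4, 8, 3, 224, 224) is flattened to
        # (32, 3, 224, 224), and gt_labels (4, 8) to (32,).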
        if img.dim() == 5:
            new_shape = [
                img.shape[0] * img.shape[1], img.shape[2], img.shape[3],
                img.shape[4]
            ]
            img = img.view(new_shape)
            gt_labels = gt_labels.view([-1])

        for preprocess in self.train_preprocess:
            img = preprocess(img)

        # When the dataset size is odd, the last batch of each epoch has an
        # odd size, which makes mixup raise an error. To avoid this, mixup is
        # applied only when the batch size is even.
        if hasattr(self, 'mixup') and len(img) % 2 == 0:
            img, gt_labels = self.mixup(img, gt_labels)

        x = self.forward_backbone(img)

        if self.neck_num > 0:
            tmp = []
            for idx in range(self.neck_num):
                h = getattr(self, 'neck_%d' % idx)
                tmp += h(x)
            x = tmp

        losses = {}
        for idx in range(self.head_num):
            h = getattr(self, 'head_%d' % idx)
            outs = h(x)
            loss_inputs = (outs, gt_labels)
            hlosses = h.loss(*loss_inputs)
            if 'loss' in losses.keys():
                losses['loss'] += hlosses['loss']
            else:
                losses['loss'] = hlosses['loss']

        return losses

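    # Training-step sketch (hypothetical usage, assuming a built model and a
    # labeled batch):
    #
    #   losses = model(img, mode='train', gt_labels=labels)
    #   losses['loss'].backward()
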
    # @torch.jit.unused
    def forward_test(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Generate class probabilities and predicted classes from an image.

        Only supports one neck + one head.
        """
        x = self.forward_backbone(img)  # list of tensors

        # hasattr check kept instead of `if self.neck_num > 0` for jit
        # compatibility
        if hasattr(self, 'neck_0'):
            x = self.neck_0([i for i in x])

        out = self.head_0(x)[0].cpu()
        result = {}
        result['prob'] = self.activate_fn(out)
        # argmax over the class dimension so batched inputs yield one
        # prediction per sample
        result['class'] = torch.argmax(result['prob'], dim=1)
        return result

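    # Inference sketch (hypothetical usage):
    #
    #   with torch.no_grad():
    #       result = model(img, mode='test')
    #   result['prob']   # (N, num_classes) softmax probabilities
    #   result['class']  # (N,) predicted class indices
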
    @torch.jit.unused
    def forward_test_label(self, img, gt_labels) -> Dict[str, torch.Tensor]:
        """Forward one neck + one head and return the head output (key
        'neck') together with the ground-truth labels (key 'gt_labels').

        Note: the head must be initialized with the input feature index set.
        """
        x = self.forward_backbone(img)  # list of tensors

        if hasattr(self, 'neck_0'):
            x = self.neck_0([i for i in x])

        out = [self.head_0(x)[0].cpu()]
        keys = ['neck', 'gt_labels']
        out.append(gt_labels.cpu())
        return dict(zip(keys, out))

    def aug_test(self, imgs):
        raise NotImplementedError

    def forward_feature(self, img) -> Dict[str, torch.Tensor]:
        """Forward the backbone + neck / multi-neck and collect the output
        features into a dict.

        If self.neck_num == 0, only the backbone is forwarded and its
        avg-pooled feature is returned under the key 'neck'.
        If self.neck_num > 0, each neck's features are returned under keys of
        the form neck_<neck idx>_<feature idx>, such as 'neck_0_0'; the first
        neck feature is also aliased as 'neck'.

        Returns:
            dict: feature name -> feature tensor
        """
        return_dict = {}
        x = self.backbone(img)
        if hasattr(self, 'neck_0'):
            tmp = []
            for idx in range(self.neck_num):
                neck_name = 'neck_%d' % idx
                h = getattr(self, neck_name)
                neck_h = h([i for i in x])
                tmp = tmp + neck_h
                for j in range(len(neck_h)):
                    neck_name = 'neck_%d_%d' % (idx, j)
                    return_dict[neck_name] = neck_h[j]
                    if neck_name not in self.extract_list:
                        self.extract_list.append(neck_name)
            return_dict['neck'] = tmp[0]
        else:
            feature = self.avg_pool(x[-1])
            feature = feature.view(feature.size(0), -1)
            return_dict['neck'] = feature

        return return_dict

    def update_extract_list(self, key):
        if key not in self.extract_list:
            self.extract_list.append(key)

    def forward(
        self,
        img: torch.Tensor,
        mode: str = 'train',
        gt_labels: Optional[torch.Tensor] = None,
        img_metas: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Dispatch to forward_train / forward_test / forward_onnx /
        forward_feature according to `mode`.
        """
        # TODO: support Dict[any, any] type of img_metas
        del img_metas  # placeholder argument kept for jit support
        if mode == 'train':
            assert gt_labels is not None
            return self.forward_train(img, gt_labels)
        elif mode == 'test':
            if gt_labels is not None:
                return self.forward_test_label(img, gt_labels)
            else:
                return self.forward_test(img)
        elif mode == 'onnx':
            return self.forward_onnx(img)
        elif mode == 'extract':
            rd = self.forward_feature(img)
            rv = {}
            for name in self.extract_list:
                if name in rd.keys():
                    rv[name] = rd[name].cpu()
                else:
                    raise ValueError(
                        'Extract {} is not supported in classification models'.
                        format(name))
            if gt_labels is not None:
                rv['gt_labels'] = gt_labels.cpu()
            return rv
        else:
            raise KeyError('No such mode: {}'.format(mode))
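

# Feature-extraction sketch (hypothetical usage, assuming a built model):
#
#   model.update_extract_list('neck_0_0')  # register an extra key to export
#   feats = model(img, mode='extract')     # dict of cpu tensors keyed by name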